Skip to content

Commit 11959e0

Browse files
shapovalovfacebook-github-bot
authored andcommitted
Subsets in dataset iterators
Summary: For the new API, filtering iterators over sequences by subsets is quite helpful. The change is backwards compatible. Reviewed By: bottler Differential Revision: D42739669 fbshipit-source-id: d150a404aeaf42fd04a81304c63a4cba203f897d
1 parent 54eb76d commit 11959e0

File tree

3 files changed

+57
-7
lines changed

3 files changed

+57
-7
lines changed

pytorch3d/implicitron/dataset/dataset_base.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def __len__(self) -> int:
237237
raise NotImplementedError()
238238

239239
def get_frame_numbers_and_timestamps(
240-
self, idxs: Sequence[int]
240+
self, idxs: Sequence[int], subset_filter: Optional[Sequence[str]] = None
241241
) -> List[Tuple[int, float]]:
242242
"""
243243
If the sequences in the dataset are videos rather than
@@ -251,7 +251,9 @@ def get_frame_numbers_and_timestamps(
251251
frames.
252252
253253
Args:
254-
idx: frame index in self
254+
idxs: frame index in self
255+
subset_filter: If given, an index in idxs is ignored if the
256+
corresponding frame is not in any of the named subsets.
255257
256258
Returns:
257259
tuple of
@@ -291,7 +293,7 @@ def category_to_sequence_names(self) -> Dict[str, List[str]]:
291293
return dict(c2seq)
292294

293295
def sequence_frames_in_order(
294-
self, seq_name: str
296+
self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
295297
) -> Iterator[Tuple[float, int, int]]:
296298
"""Returns an iterator over the frame indices in a given sequence.
297299
We attempt to first sort by timestamp (if they are available),
@@ -308,7 +310,9 @@ def sequence_frames_in_order(
308310
"""
309311
# pyre-ignore[16]
310312
seq_frame_indices = self._seq_to_idx[seq_name]
311-
nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)
313+
nos_timestamps = self.get_frame_numbers_and_timestamps(
314+
seq_frame_indices, subset_filter
315+
)
312316

313317
yield from sorted(
314318
[
@@ -317,11 +321,13 @@ def sequence_frames_in_order(
317321
]
318322
)
319323

320-
def sequence_indices_in_order(self, seq_name: str) -> Iterator[int]:
324+
def sequence_indices_in_order(
325+
self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
326+
) -> Iterator[int]:
321327
"""Same as `sequence_frames_in_order` but returns the iterator over
322328
only dataset indices.
323329
"""
324-
for _, _, idx in self.sequence_frames_in_order(seq_name):
330+
for _, _, idx in self.sequence_frames_in_order(seq_name, subset_filter):
325331
yield idx
326332

327333
# frame_data_type is the actual type of frames returned by the dataset.

pytorch3d/implicitron/dataset/json_index_dataset.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -888,10 +888,16 @@ def _local_path(self, path: str) -> str:
888888
return self.path_manager.get_local_path(path)
889889

890890
def get_frame_numbers_and_timestamps(
891-
self, idxs: Sequence[int]
891+
self, idxs: Sequence[int], subset_filter: Optional[Sequence[str]] = None
892892
) -> List[Tuple[int, float]]:
893893
out: List[Tuple[int, float]] = []
894894
for idx in idxs:
895+
if (
896+
subset_filter is not None
897+
and self.frame_annots[idx]["subset"] not in subset_filter
898+
):
899+
continue
900+
895901
# pyre-ignore[16]
896902
frame_annotation = self.frame_annots[idx]["frame_annotation"]
897903
out.append(

tests/implicitron/test_data_json_index.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,41 @@ def test_loaders(self):
4040
self.assertEqual(len(data_sets.train), 81)
4141
self.assertEqual(len(data_sets.val), 102)
4242
self.assertEqual(len(data_sets.test), 102)
43+
44+
def test_visitor_subsets(self):
45+
args = get_default_args(ImplicitronDataSource)
46+
args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
47+
dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
48+
dataset_args.category = "skateboard"
49+
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
50+
dataset_args.test_restrict_sequence_id = 0
51+
dataset_args.dataset_JsonIndexDataset_args.limit_sequences_to = 1
52+
53+
data_source = ImplicitronDataSource(**args)
54+
datasets, _ = data_source.get_datasets_and_dataloaders()
55+
dataset = datasets.test
56+
57+
sequences = list(dataset.sequence_names())
58+
self.assertEqual(len(sequences), 1)
59+
i = 0
60+
for seq in sequences:
61+
last_ts = float("-Inf")
62+
seq_frames = list(dataset.sequence_frames_in_order(seq))
63+
self.assertEqual(len(seq_frames), 102)
64+
for ts, _, idx in seq_frames:
65+
self.assertEqual(i, idx)
66+
i += 1
67+
self.assertGreaterEqual(ts, last_ts)
68+
last_ts = ts
69+
70+
last_ts = float("-Inf")
71+
known_frames = list(dataset.sequence_frames_in_order(seq, "test_known"))
72+
self.assertEqual(len(known_frames), 81)
73+
for ts, _, _ in known_frames:
74+
self.assertGreaterEqual(ts, last_ts)
75+
last_ts = ts
76+
77+
known_indices = list(dataset.sequence_indices_in_order(seq, "test_known"))
78+
self.assertEqual(len(known_indices), 81)
79+
80+
break # testing only the first sequence

0 commit comments

Comments
 (0)