Skip to content

Commit b046259

Browse files
shapovalov authored and facebook-github-bot committed
Refactor: FrameDataBuilder is more extensible.
Summary: This is mostly a refactoring diff to reduce friction in extending the frame data. Slight functional changes: dataset getitem now accepts (seq_name, frame_number_as_singleton_tensor) as a non-advertised feature. Otherwise this code crashes: ``` item = dataset[0] dataset[item.sequence_name, item.frame_number] ``` Reviewed By: bottler Differential Revision: D45780175 fbshipit-source-id: 75b8e8d3dabed954a804310abdbd8ab44a8dea29
1 parent d08fe6d commit b046259

File tree

5 files changed

+102
-41
lines changed

5 files changed

+102
-41
lines changed

projects/implicitron_trainer/tests/test_experiment.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,11 @@ def test_yaml_contents(self):
132132
# Check that the default config values, defined by Experiment and its
133133
# members, is what we expect it to be.
134134
cfg = OmegaConf.structured(experiment.Experiment)
135+
# the following removes the possible effect of env variables
136+
ds_arg = cfg.data_source_ImplicitronDataSource_args
137+
ds_arg.dataset_map_provider_JsonIndexDatasetMapProvider_args.dataset_root = ""
138+
ds_arg.dataset_map_provider_JsonIndexDatasetMapProviderV2_args.dataset_root = ""
139+
cfg.training_loop_ImplicitronTrainingLoop_args.visdom_port = 8097
135140
yaml = OmegaConf.to_yaml(cfg, sort_keys=False)
136141
if DEBUG:
137142
(DATA_DIR / "experiment.yaml").write_text(yaml)

pytorch3d/implicitron/dataset/frame_data.py

Lines changed: 48 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,10 @@ def crop_by_metadata_bbox_(
203203
when no image has been loaded)
204204
"""
205205
if self.bbox_xywh is None:
206-
raise ValueError("Attempted cropping by metadata with empty bounding box")
206+
raise ValueError(
207+
"Attempted cropping by metadata with empty bounding box. Consider either"
208+
" to remove_empty_masks or turn off box_crop in the dataset config."
209+
)
207210

208211
if not self._uncropped:
209212
raise ValueError(
@@ -528,12 +531,7 @@ def __post_init__(self) -> None:
528531
"Make sure it is set in either FrameDataBuilder or Dataset params."
529532
)
530533

531-
if self.path_manager is None:
532-
dataset_root_exists = os.path.isdir(self.dataset_root) # pyre-ignore
533-
else:
534-
dataset_root_exists = self.path_manager.isdir(self.dataset_root)
535-
536-
if load_any_blob and not dataset_root_exists:
534+
if load_any_blob and not self._exists_in_dataset_root(""):
537535
raise ValueError(
538536
f"dataset_root is passed but {self.dataset_root} does not exist."
539537
)
@@ -604,14 +602,27 @@ def build(
604602
frame_data.image_size_hw = image_size_hw # original image size
605603
# image size after crop/resize
606604
frame_data.effective_image_size_hw = image_size_hw
605+
image_path = None
606+
dataset_root = self.dataset_root
607+
if frame_annotation.image.path is not None and dataset_root is not None:
608+
image_path = os.path.join(dataset_root, frame_annotation.image.path)
609+
frame_data.image_path = image_path
607610

608611
if load_blobs and self.load_images:
609-
(
610-
frame_data.image_rgb,
611-
frame_data.image_path,
612-
) = self._load_images(frame_annotation, frame_data.fg_probability)
612+
if image_path is None:
613+
raise ValueError("Image path is required to load images.")
614+
615+
image_np = load_image(self._local_path(image_path))
616+
frame_data.image_rgb = self._postprocess_image(
617+
image_np, frame_annotation.image.size, frame_data.fg_probability
618+
)
613619

614-
if load_blobs and self.load_depths and frame_annotation.depth is not None:
620+
if (
621+
load_blobs
622+
and self.load_depths
623+
and frame_annotation.depth is not None
624+
and frame_annotation.depth.path is not None
625+
):
615626
(
616627
frame_data.depth_map,
617628
frame_data.depth_path,
@@ -652,44 +663,42 @@ def _load_fg_probability(
652663

653664
return fg_probability, full_path
654665

655-
def _load_images(
666+
def _postprocess_image(
    self,
    image_np: np.ndarray,
    image_size: Tuple[int, int],
    fg_probability: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Convert a decoded image array to a float tensor, validate its spatial
    size against the annotation, and optionally mask out the background.

    Args:
        image_np: decoded image; its last two dimensions are (H, W)
            (CHW as produced by load_image — TODO confirm for other callers).
        image_size: expected (height, width) from the frame annotation.
        fg_probability: foreground probability map; must not be None when
            self.mask_images is set.

    Returns:
        The image as a float tensor, multiplied by fg_probability when
        self.mask_images is set.

    Raises:
        ValueError: if the image's spatial size does not match image_size.
    """
    image_rgb = safe_as_tensor(image_np, torch.float)

    if image_rgb.shape[-2:] != image_size:
        raise ValueError(f"bad image size: {image_rgb.shape[-2:]} vs {image_size}!")

    if self.mask_images:
        # masking requires the foreground probability to have been loaded first
        assert fg_probability is not None
        image_rgb *= fg_probability

    return image_rgb
674682

675683
def _load_mask_depth(
676684
self,
677685
entry: types.FrameAnnotation,
678686
fg_probability: Optional[torch.Tensor],
679687
) -> Tuple[torch.Tensor, str, torch.Tensor]:
680688
entry_depth = entry.depth
681-
assert self.dataset_root is not None and entry_depth is not None
682-
path = os.path.join(self.dataset_root, entry_depth.path)
689+
dataset_root = self.dataset_root
690+
assert dataset_root is not None
691+
assert entry_depth is not None and entry_depth.path is not None
692+
path = os.path.join(dataset_root, entry_depth.path)
683693
depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)
684694

685695
if self.mask_depths:
686696
assert fg_probability is not None
687697
depth_map *= fg_probability
688698

689-
if self.load_depth_masks:
690-
assert entry_depth.mask_path is not None
691-
# pyre-ignore
692-
mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
699+
mask_path = entry_depth.mask_path
700+
if self.load_depth_masks and mask_path is not None:
701+
mask_path = os.path.join(dataset_root, mask_path)
693702
depth_mask = load_depth_mask(self._local_path(mask_path))
694703
else:
695704
depth_mask = torch.ones_like(depth_map)
@@ -745,6 +754,16 @@ def _local_path(self, path: str) -> str:
745754
return path
746755
return self.path_manager.get_local_path(path)
747756

757+
def _exists_in_dataset_root(self, relpath: str) -> bool:
    """
    Check whether `relpath` exists under `self.dataset_root`.

    Args:
        relpath: path relative to the dataset root; "" checks the
            dataset root itself.

    Returns:
        False when dataset_root is unset/empty; otherwise whether the
        joined path exists, checked through path_manager when one is
        configured, or via the local filesystem otherwise.
    """
    if not self.dataset_root:
        return False

    full_path = os.path.join(self.dataset_root, relpath)
    if self.path_manager is None:
        return os.path.exists(full_path)
    return self.path_manager.exists(full_path)
766+
748767

749768
@registry.register
750769
class FrameDataBuilder(GenericWorkaround, GenericFrameDataBuilder[FrameData]):

pytorch3d/implicitron/dataset/sql_dataset.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,9 @@ def _get_item(
210210
seq, frame = self._index.index[frame_idx]
211211
else:
212212
seq, frame, *rest = frame_idx
213+
if isinstance(frame, torch.LongTensor):
214+
frame = frame.item()
215+
213216
if (seq, frame) not in self._index.index:
214217
raise IndexError(
215218
f"Sequence-frame index {frame_idx} not found; was it filtered out?"

pytorch3d/implicitron/dataset/utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -225,19 +225,23 @@ def resize_image(
225225
return imre_, minscale, mask
226226

227227

228+
def transpose_normalize_image(image: np.ndarray) -> np.ndarray:
    """
    Convert an (H, W) or (H, W, C) image to (C, H, W) float32 in [0, 1].

    2D inputs are promoted to a single trailing channel by np.atleast_3d,
    so the result always has a leading channel dimension.

    Args:
        image: image array of shape (H, W) or (H, W, C), typically uint8.

    Returns:
        float32 array of shape (C, H, W) with values divided by 255.
    """
    im = np.atleast_3d(image).transpose((2, 0, 1))
    return im.astype(np.float32) / 255.0
231+
232+
228233
def load_image(path: str) -> np.ndarray:
    """
    Load an image file, convert it to RGB, and return it as a normalized
    CHW float32 array.

    Args:
        path: local filesystem path to the image file.

    Returns:
        (3, H, W) float32 array with values in [0, 1].
    """
    with Image.open(path) as pil_im:
        im = np.array(pil_im.convert("RGB"))

    return transpose_normalize_image(im)
234238

235239

236240
def load_mask(path: str) -> np.ndarray:
    """
    Load a mask image and return it as a normalized CHW float32 array.

    Single-channel masks come back with shape (1, H, W); multi-channel
    inputs keep their channels (in the leading dimension).

    Args:
        path: local filesystem path to the mask image.

    Returns:
        float32 array in [0, 1] with a leading channel dimension.
    """
    with Image.open(path) as pil_im:
        mask = np.array(pil_im)

    return transpose_normalize_image(mask)
241245

242246

243247
def load_depth(path: str, scale_adjustment: float) -> np.ndarray:

tests/implicitron/test_frame_data_builder.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
load_image,
2626
load_mask,
2727
safe_as_tensor,
28+
transpose_normalize_image,
2829
)
2930
from pytorch3d.implicitron.tools.config import get_default_args
3031
from pytorch3d.renderer.cameras import PerspectiveCameras
@@ -123,14 +124,15 @@ def test_load_and_adjust_frame_data(self):
123124
# assert bboxes shape
124125
self.assertEqual(self.frame_data.bbox_xywh.shape, torch.Size([4]))
125126

126-
(
127-
self.frame_data.image_rgb,
128-
self.frame_data.image_path,
129-
) = self.frame_data_builder._load_images(
130-
self.frame_annotation, self.frame_data.fg_probability
127+
image_path = os.path.join(
128+
self.frame_data_builder.dataset_root, self.frame_annotation.image.path
129+
)
130+
image_np = load_image(self.frame_data_builder._local_path(image_path))
131+
self.assertIsInstance(image_np, np.ndarray)
132+
self.frame_data.image_rgb = self.frame_data_builder._postprocess_image(
133+
image_np, self.frame_annotation.image.size, self.frame_data.fg_probability
131134
)
132-
self.assertEqual(type(self.frame_data.image_rgb), np.ndarray)
133-
self.assertIsNotNone(self.frame_data.image_path)
135+
self.assertIsInstance(self.frame_data.image_rgb, torch.Tensor)
134136

135137
(
136138
self.frame_data.depth_map,
@@ -184,6 +186,34 @@ def test_load_and_adjust_frame_data(self):
184186
)
185187
self.assertEqual(type(self.frame_data.camera), PerspectiveCameras)
186188

189+
def test_transpose_normalize_image(self):
    """Check that transpose_normalize_image round-trips through its inverse."""

    def inverse_transpose_normalize_image(image: np.ndarray) -> np.ndarray:
        # invert the helper: scale back to [0, 255] and restore HWC layout
        im = image * 255.0
        return im.transpose((1, 2, 0)).astype(np.uint8)

    # 2D (grayscale) input: round trip adds a trailing channel axis
    input_image = np.array(
        [[10, 20, 30], [40, 50, 60], [70, 80, 90]], dtype=np.uint8
    )
    expected_input = inverse_transpose_normalize_image(
        transpose_normalize_image(input_image)
    )
    self.assertClose(input_image[..., None], expected_input)

    # 3D (multi-channel) input: round trip reproduces the input exactly
    input_image = np.array(
        [
            [[10, 20, 30], [40, 50, 60], [70, 80, 90]],
            [[100, 110, 120], [130, 140, 150], [160, 170, 180]],
            [[190, 200, 210], [220, 230, 240], [250, 255, 255]],
        ],
        dtype=np.uint8,
    )
    expected_input = inverse_transpose_normalize_image(
        transpose_normalize_image(input_image)
    )
    self.assertClose(input_image, expected_input)
216+
187217
def test_load_image(self):
188218
path = os.path.join(self.dataset_root, self.frame_annotation.image.path)
189219
local_path = self.path_manager.get_local_path(path)

0 commit comments

Comments (0)