Skip to content

Commit c311a4c

Browse files
Darijan Gudeljfacebook-github-bot
Darijan Gudelj
authored andcommitted
Enable mixed frame raysampling
Summary: Changed ray_sampler and metrics to be able to use mixed frame raysampling. Ray_sampler now has a new member which it passes to the pytorch3d raysampler. If the raybundle is heterogeneous metrics now samples images by padding xys first. This reduces memory consumption. Reviewed By: bottler, kjchalup Differential Revision: D39542221 fbshipit-source-id: a6fec23838d3049ae5c2fd2e1f641c46c7c927e3
1 parent ad8907d commit c311a4c

File tree

8 files changed

+102
-35
lines changed

8 files changed

+102
-35
lines changed

projects/implicitron_trainer/experiment.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ def run(self) -> None:
222222
train_loader=train_loader,
223223
val_loader=val_loader,
224224
test_loader=test_loader,
225+
train_dataset=datasets.train,
225226
model=model,
226227
optimizer=optimizer,
227228
scheduler=scheduler,

projects/implicitron_trainer/tests/experiment.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ model_factory_ImplicitronModelFactory_args:
197197
n_pts_per_ray_training: 64
198198
n_pts_per_ray_evaluation: 64
199199
n_rays_per_image_sampled_from_mask: 1024
200+
n_rays_total_training: null
200201
stratified_point_sampling_training: true
201202
stratified_point_sampling_evaluation: false
202203
scene_extent: 8.0
@@ -208,6 +209,7 @@ model_factory_ImplicitronModelFactory_args:
208209
n_pts_per_ray_training: 64
209210
n_pts_per_ray_evaluation: 64
210211
n_rays_per_image_sampled_from_mask: 1024
212+
n_rays_total_training: null
211213
stratified_point_sampling_training: true
212214
stratified_point_sampling_evaluation: false
213215
min_depth: 0.1

pytorch3d/implicitron/models/generic_model.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ def curried_viewpooler(pts):
473473
self.view_metrics(
474474
results=preds,
475475
raymarched=rendered,
476-
xys=ray_bundle.xys,
476+
ray_bundle=ray_bundle,
477477
image_rgb=safe_slice_targets(image_rgb),
478478
depth_map=safe_slice_targets(depth_map),
479479
fg_probability=safe_slice_targets(fg_probability),
@@ -932,6 +932,11 @@ def _chunk_generator(
932932
if len(iter) >= tqdm_trigger_threshold:
933933
iter = tqdm.tqdm(iter)
934934

935+
def _safe_slice(
936+
tensor: Optional[torch.Tensor], start_idx: int, end_idx: int
937+
) -> Optional[torch.Tensor]:
938+
return tensor[start_idx:end_idx] if tensor is not None else None
939+
935940
for start_idx in iter:
936941
end_idx = min(start_idx + chunk_size_in_rays, n_rays)
937942
ray_bundle_chunk = ImplicitronRayBundle(
@@ -943,6 +948,8 @@ def _chunk_generator(
943948
:, start_idx:end_idx
944949
],
945950
xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
951+
camera_ids=_safe_slice(ray_bundle.camera_ids, start_idx, end_idx),
952+
camera_counts=_safe_slice(ray_bundle.camera_counts, start_idx, end_idx),
946953
)
947954
extra_args = kwargs.copy()
948955
for k, v in chunked_inputs.items():

pytorch3d/implicitron/models/metrics.py

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
from typing import Any, Dict, Optional
1010

1111
import torch
12+
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
1213
from pytorch3d.implicitron.tools import metric_utils as utils
1314
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
15+
from pytorch3d.ops import packed_to_padded, padded_to_packed
1416
from pytorch3d.renderer import utils as rend_utils
1517

1618
from .renderer.base import RendererOutput
@@ -60,7 +62,7 @@ def __post_init__(self) -> None:
6062
def forward(
6163
self,
6264
raymarched: RendererOutput,
63-
xys: torch.Tensor,
65+
ray_bundle: ImplicitronRayBundle,
6466
image_rgb: Optional[torch.Tensor] = None,
6567
depth_map: Optional[torch.Tensor] = None,
6668
fg_probability: Optional[torch.Tensor] = None,
@@ -79,10 +81,8 @@ def forward(
7981
names of the output metrics `metric_name_i` with their corresponding
8082
values `metric_value_i` represented as 0-dimensional float tensors.
8183
raymarched: Output of the renderer.
82-
xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
83-
the predictions are defined. All ground truth inputs are sampled at
84-
these locations in order to extract values that correspond to the
85-
predictions.
84+
ray_bundle: ImplicitronRayBundle object which was used to produce the raymarched
85+
object
8686
image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
8787
values.
8888
depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
@@ -141,7 +141,7 @@ class ViewMetrics(ViewMetricsBase):
141141
def forward(
142142
self,
143143
raymarched: RendererOutput,
144-
xys: torch.Tensor,
144+
ray_bundle: ImplicitronRayBundle,
145145
image_rgb: Optional[torch.Tensor] = None,
146146
depth_map: Optional[torch.Tensor] = None,
147147
fg_probability: Optional[torch.Tensor] = None,
@@ -165,10 +165,8 @@ def forward(
165165
input 3D coordinates used to compute the eikonal loss.
166166
raymarched.aux["density_grid"]: A tensor of shape `(B, Hg, Wg, Dg, 1)`
167167
containing a `Hg x Wg x Dg` voxel grid of density values.
168-
xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
169-
the predictions are defined. All ground truth inputs are sampled at
170-
these locations in order to extract values that correspond to the
171-
predictions.
168+
ray_bundle: ImplicitronRayBundle object which was used to produce the raymarched
169+
object
172170
image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
173171
values.
174172
depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
@@ -209,7 +207,7 @@ def forward(
209207
"""
210208
metrics = self._calculate_stage(
211209
raymarched,
212-
xys,
210+
ray_bundle,
213211
image_rgb,
214212
depth_map,
215213
fg_probability,
@@ -221,7 +219,7 @@ def forward(
221219
metrics.update(
222220
self(
223221
raymarched.prev_stage,
224-
xys,
222+
ray_bundle,
225223
image_rgb,
226224
depth_map,
227225
fg_probability,
@@ -235,7 +233,7 @@ def forward(
235233
def _calculate_stage(
236234
self,
237235
raymarched: RendererOutput,
238-
xys: torch.Tensor,
236+
ray_bundle: ImplicitronRayBundle,
239237
image_rgb: Optional[torch.Tensor] = None,
240238
depth_map: Optional[torch.Tensor] = None,
241239
fg_probability: Optional[torch.Tensor] = None,
@@ -253,6 +251,27 @@ def _calculate_stage(
253251
_reshape_nongrid_var(x)
254252
for x in [raymarched.features, raymarched.masks, raymarched.depths]
255253
]
254+
xys = ray_bundle.xys
255+
256+
# If ray_bundle is packed than we can sample images in padded state to lower
257+
# memory requirements. Instead of having one image for every element in
258+
# ray_bundle we can than have one image per unique sampled camera.
259+
if ray_bundle.is_packed():
260+
# pyre-ignore[6]
261+
cumsum = torch.cumsum(ray_bundle.camera_counts, dim=0, dtype=torch.long)
262+
first_idxs = torch.cat(
263+
(
264+
# pyre-ignore[16]
265+
ray_bundle.camera_counts.new_zeros((1,), dtype=torch.long),
266+
cumsum[:-1],
267+
)
268+
)
269+
# pyre-ignore[16]
270+
num_inputs = int(ray_bundle.camera_counts.sum())
271+
# pyre-ignore[6]
272+
max_size = int(torch.max(ray_bundle.camera_counts))
273+
xys = packed_to_padded(xys, first_idxs, max_size)
274+
256275
# reshape the sampling grid as well
257276
# TODO: we can get rid of the singular dimension here and in _reshape_nongrid_var
258277
# now that we use rend_utils.ndc_grid_sample
@@ -262,7 +281,20 @@ def _calculate_stage(
262281
def sample(tensor, mode):
263282
if tensor is None:
264283
return tensor
265-
return rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
284+
if ray_bundle.is_packed():
285+
# select images that corespond to sampled cameras if raybundle is packed
286+
tensor = tensor[ray_bundle.camera_ids]
287+
result = rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
288+
if ray_bundle.is_packed():
289+
# Images after sampling are in a form [batch, 3, max_num_rays, 1],
290+
# packed_to_padded combines first two dimensions so we need to swap 1st
291+
# and 2nd dimension. the result is [n_rays_total_training, 1, 3, 1]
292+
# (we use keepdim=True).
293+
result = result.transpose(1, 2)
294+
result = padded_to_packed(result, first_idxs, num_inputs)[:, None]
295+
result = result.transpose(1, 2)
296+
297+
return result
266298

267299
# eval all results in this size
268300
image_rgb = sample(image_rgb, mode="bilinear")

pytorch3d/implicitron/models/renderer/ray_point_refiner.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import torch
8+
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
89
from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields
9-
from pytorch3d.renderer import RayBundle
10+
1011
from pytorch3d.renderer.implicit.sample_pdf import sample_pdf
1112

1213

@@ -42,21 +43,21 @@ def __post_init__(self) -> None:
4243

4344
def forward(
4445
self,
45-
input_ray_bundle: RayBundle,
46+
input_ray_bundle: ImplicitronRayBundle,
4647
ray_weights: torch.Tensor,
4748
**kwargs,
48-
) -> RayBundle:
49+
) -> ImplicitronRayBundle:
4950
"""
5051
Args:
51-
input_ray_bundle: An instance of `RayBundle` specifying the
52+
input_ray_bundle: An instance of `ImplicitronRayBundle` specifying the
5253
source rays for sampling of the probability distribution.
5354
ray_weights: A tensor of shape
5455
`(..., input_ray_bundle.legths.shape[-1])` with non-negative
5556
elements defining the probability distribution to sample
5657
ray points from.
5758
5859
Returns:
59-
ray_bundle: A new `RayBundle` instance containing the input ray
60+
ray_bundle: A new `ImplicitronRayBundle` instance containing the input ray
6061
points together with `n_pts_per_ray` additionally sampled
6162
points per ray. For each ray, the lengths are sorted.
6263
"""
@@ -79,9 +80,6 @@ def forward(
7980
# Resort by depth.
8081
z_vals, _ = torch.sort(z_vals, dim=-1)
8182

82-
return RayBundle(
83-
origins=input_ray_bundle.origins,
84-
directions=input_ray_bundle.directions,
85-
lengths=z_vals,
86-
xys=input_ray_bundle.xys,
87-
)
83+
new_bundle = ImplicitronRayBundle(**vars(input_ray_bundle))
84+
new_bundle.lengths = z_vals
85+
return new_bundle

pytorch3d/implicitron/models/renderer/ray_sampler.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,17 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
7272
sampling_mode_evaluation: Same as above but for evaluation.
7373
n_pts_per_ray_training: The number of points sampled along each ray during training.
7474
n_pts_per_ray_evaluation: The number of points sampled along each ray during evaluation.
75-
n_rays_per_image_sampled_from_mask: The amount of rays to be sampled from the image grid
75+
n_rays_per_image_sampled_from_mask: The amount of rays to be sampled from the image
76+
grid. Given a batch of image grids, this many is sampled from each.
77+
`n_rays_per_image_sampled_from_mask` and `n_rays_total_training` cannot both be
78+
defined.
79+
n_rays_total_training: (optional) How many rays in total to sample from the entire
80+
batch of provided image grid. The result is as if `n_rays_total_training`
81+
cameras/image grids were sampled with replacement from the cameras / image grids
82+
provided and for every camera one ray was sampled.
83+
`n_rays_per_image_sampled_from_mask` and `n_rays_total_training` cannot both be
84+
defined, to use you have to set `n_rays_per_image` to None.
85+
Used only for EvaluationMode.TRAINING.
7686
stratified_point_sampling_training: if set, performs stratified random sampling
7787
along the ray; otherwise takes ray points at deterministic offsets.
7888
stratified_point_sampling_evaluation: Same as above but for evaluation.
@@ -85,14 +95,23 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
8595
sampling_mode_evaluation: str = "full_grid"
8696
n_pts_per_ray_training: int = 64
8797
n_pts_per_ray_evaluation: int = 64
88-
n_rays_per_image_sampled_from_mask: int = 1024
98+
n_rays_per_image_sampled_from_mask: Optional[int] = 1024
99+
n_rays_total_training: Optional[int] = None
89100
# stratified sampling vs taking points at deterministic offsets
90101
stratified_point_sampling_training: bool = True
91102
stratified_point_sampling_evaluation: bool = False
92103

93104
def __post_init__(self):
94105
super().__init__()
95106

107+
if (self.n_rays_per_image_sampled_from_mask is not None) and (
108+
self.n_rays_total_training is not None
109+
):
110+
raise ValueError(
111+
"Cannot both define n_rays_total_training and "
112+
"n_rays_per_image_sampled_from_mask."
113+
)
114+
96115
self._sampling_mode = {
97116
EvaluationMode.TRAINING: RenderSamplingMode(self.sampling_mode_training),
98117
EvaluationMode.EVALUATION: RenderSamplingMode(
@@ -110,9 +129,11 @@ def __post_init__(self):
110129
if self._sampling_mode[EvaluationMode.TRAINING]
111130
== RenderSamplingMode.MASK_SAMPLE
112131
else None,
132+
n_rays_total=self.n_rays_total_training,
113133
unit_directions=True,
114134
stratified_sampling=self.stratified_point_sampling_training,
115135
)
136+
116137
self._evaluation_raysampler = NDCMultinomialRaysampler(
117138
image_width=self.image_width,
118139
image_height=self.image_height,

pytorch3d/renderer/implicit/raysampling.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,12 @@ def __init__(
9090
min_depth: The minimum depth of a ray-point.
9191
max_depth: The maximum depth of a ray-point.
9292
n_rays_per_image: If given, this amount of rays are sampled from the grid.
93+
`n_rays_per_image` and `n_rays_total` cannot both be defined.
9394
n_rays_total: How many rays in total to sample from the cameras provided. The result
94-
is as if `n_rays_total` cameras were sampled with replacement from the
95-
cameras provided and for every camera one ray was sampled. If set, this disables
96-
`n_rays_per_image` and returns the HeterogeneousRayBundle with
97-
batch_size=n_rays_total.
95+
is as if `n_rays_total_training` cameras were sampled with replacement from the
96+
cameras provided and for every camera one ray was sampled. If set returns the
97+
HeterogeneousRayBundle with batch_size=n_rays_total.
98+
`n_rays_per_image` and `n_rays_total` cannot both be defined.
9899
unit_directions: whether to normalize direction vectors in ray bundle.
99100
stratified_sampling: if True, performs stratified random sampling
100101
along the ray; otherwise takes ray points at deterministic offsets.
@@ -144,13 +145,15 @@ def forward(
144145
min_depth: The minimum depth of a ray-point.
145146
max_depth: The maximum depth of a ray-point.
146147
n_rays_per_image: If given, this amount of rays are sampled from the grid.
148+
`n_rays_per_image` and `n_rays_total` cannot both be defined.
147149
n_pts_per_ray: The number of points sampled along each ray.
148150
stratified_sampling: if set, overrides stratified_sampling provided
149151
in __init__.
150152
n_rays_total: How many rays in total to sample from the cameras provided. The result
151153
is as if `n_rays_total_training` cameras were sampled with replacement from the
152-
cameras provided and for every camera one ray was sampled. If set, returns the
154+
cameras provided and for every camera one ray was sampled. If set returns the
153155
HeterogeneousRayBundle with batch_size=n_rays_total.
156+
`n_rays_per_image` and `n_rays_total` cannot both be defined.
154157
Returns:
155158
A named tuple RayBundle or dataclass HeterogeneousRayBundle with the
156159
following fields:
@@ -352,13 +355,15 @@ def __init__(
352355
min_y: The smallest y-coordinate of each ray's source pixel.
353356
max_y: The largest y-coordinate of each ray's source pixel.
354357
n_rays_per_image: The number of rays randomly sampled in each camera.
358+
`n_rays_per_image` and `n_rays_total` cannot both be defined.
355359
n_pts_per_ray: The number of points sampled along each ray.
356360
min_depth: The minimum depth of each ray-point.
357361
max_depth: The maximum depth of each ray-point.
358362
n_rays_total: How many rays in total to sample from the cameras provided. The result
359363
is as if `n_rays_total_training` cameras were sampled with replacement from the
360-
cameras provided and for every camera one ray was sampled. If set, this returns
361-
the HeterogeneousRayBundleyBundle with batch_size=n_rays_total.
364+
cameras provided and for every camera one ray was sampled. If set returns the
365+
HeterogeneousRayBundle with batch_size=n_rays_total.
366+
`n_rays_per_image` and `n_rays_total` cannot both be defined.
362367
unit_directions: whether to normalize direction vectors in ray bundle.
363368
stratified_sampling: if True, performs stratified sampling in n_pts_per_ray
364369
bins for each ray; otherwise takes n_pts_per_ray deterministic points

tests/implicitron/data/overrides.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ raysampler_AdaptiveRaySampler_args:
5959
n_pts_per_ray_training: 64
6060
n_pts_per_ray_evaluation: 64
6161
n_rays_per_image_sampled_from_mask: 1024
62+
n_rays_total_training: null
6263
stratified_point_sampling_training: true
6364
stratified_point_sampling_evaluation: false
6465
scene_extent: 8.0

0 commit comments

Comments
 (0)