Skip to content

Commit 6ae863f

Browse files
Darijan Gudeljfacebook-github-bot
Darijan Gudelj
authored andcommitted
Heterogeneous raysampling -> RayBundleHeterogeneous
Summary: Added heterogeneous raysampling to pytorch3d raysampler, different cameras are sampled different number of times. It now returns RayBundle if heterogeneous raysampling is off and new RayBundleHeterogeneous (with added fields `camera_ids` and `camera_counts`). Heterogeneous raysampling is on if `n_rays_total` is not None. Reviewed By: bottler Differential Revision: D39542222 fbshipit-source-id: d3d88d822ec7696e856007c088dc36a1cfa8c625
1 parent 9a0f9ae commit 6ae863f

File tree

6 files changed

+325
-48
lines changed

6 files changed

+325
-48
lines changed

pytorch3d/renderer/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
EmissionAbsorptionRaymarcher,
3232
GridRaysampler,
3333
HarmonicEmbedding,
34+
HeterogeneousRayBundle,
3435
ImplicitRenderer,
3536
MonteCarloRaysampler,
3637
MultinomialRaysampler,

pytorch3d/renderer/implicit/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
)
1616
from .renderer import ImplicitRenderer, VolumeRenderer, VolumeSampler
1717
from .utils import (
18+
HeterogeneousRayBundle,
1819
ray_bundle_to_ray_points,
1920
ray_bundle_variables_to_ray_points,
2021
RayBundle,

pytorch3d/renderer/implicit/raysampling.py

Lines changed: 189 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import warnings
8-
from typing import Optional
8+
from typing import Optional, Tuple, Union
99

1010
import torch
1111
from pytorch3d.common.compat import meshgrid_ij
12+
from pytorch3d.ops import padded_to_packed
1213
from pytorch3d.renderer.cameras import CamerasBase
13-
from pytorch3d.renderer.implicit.utils import RayBundle
14+
from pytorch3d.renderer.implicit.utils import HeterogeneousRayBundle, RayBundle
1415
from torch.nn import functional as F
1516

1617

@@ -73,6 +74,7 @@ def __init__(
7374
min_depth: float,
7475
max_depth: float,
7576
n_rays_per_image: Optional[int] = None,
77+
n_rays_total: Optional[int] = None,
7678
unit_directions: bool = False,
7779
stratified_sampling: bool = False,
7880
) -> None:
@@ -88,6 +90,11 @@ def __init__(
8890
min_depth: The minimum depth of a ray-point.
8991
max_depth: The maximum depth of a ray-point.
9092
n_rays_per_image: If given, this amount of rays are sampled from the grid.
93+
n_rays_total: How many rays in total to sample from the cameras provided. The result
94+
is as if `n_rays_total` cameras were sampled with replacement from the
95+
cameras provided and for every camera one ray was sampled. If set, this disables
96+
`n_rays_per_image` and returns the HeterogeneousRayBundle with
97+
batch_size=n_rays_total.
9198
unit_directions: whether to normalize direction vectors in ray bundle.
9299
stratified_sampling: if True, performs stratified random sampling
93100
along the ray; otherwise takes ray points at deterministic offsets.
@@ -97,6 +104,7 @@ def __init__(
97104
self._min_depth = min_depth
98105
self._max_depth = max_depth
99106
self._n_rays_per_image = n_rays_per_image
107+
self._n_rays_total = n_rays_total
100108
self._unit_directions = unit_directions
101109
self._stratified_sampling = stratified_sampling
102110

@@ -125,8 +133,9 @@ def forward(
125133
n_rays_per_image: Optional[int] = None,
126134
n_pts_per_ray: Optional[int] = None,
127135
stratified_sampling: Optional[bool] = None,
136+
n_rays_total: Optional[int] = None,
128137
**kwargs,
129-
) -> RayBundle:
138+
) -> Union[RayBundle, HeterogeneousRayBundle]:
130139
"""
131140
Args:
132141
cameras: A batch of `batch_size` cameras from which the rays are emitted.
@@ -138,8 +147,15 @@ def forward(
138147
n_pts_per_ray: The number of points sampled along each ray.
139148
stratified_sampling: if set, overrides stratified_sampling provided
140149
in __init__.
150+
n_rays_total: How many rays in total to sample from the cameras provided. The result
151+
is as if `n_rays_total_training` cameras were sampled with replacement from the
152+
cameras provided and for every camera one ray was sampled. If set, this disables
153+
`n_rays_per_image` and returns the HeterogeneousRayBundle with
154+
batch_size=n_rays_total.
141155
Returns:
142-
A named tuple RayBundle with the following fields:
156+
A named tuple RayBundle or dataclass HeterogeneousRayBundle with the
157+
following fields:
158+
143159
origins: A tensor of shape
144160
`(batch_size, s1, s2, 3)`
145161
denoting the locations of ray origins in the world coordinates.
@@ -153,23 +169,56 @@ def forward(
153169
`(batch_size, s1, s2, 2)`
154170
containing the 2D image coordinates of each ray or,
155171
if mask is given, `(batch_size, n, 1, 2)`
156-
Here `s1, s2` refer to spatial dimensions. Unless the mask is
157-
given, they equal `(image_height, image_width)`, otherwise `(n, 1)`,
158-
where `n` is `n_rays_per_image` if provided, otherwise the minimum
159-
cardinality of the mask in the batch.
172+
Here `s1, s2` refer to spatial dimensions.
173+
`(s1, s2)` refer to (highest priority first):
174+
- `(1, 1)` if `n_rays_total` is provided, (batch_size=n_rays_total)
175+
- `(n_rays_per_image, 1) if `n_rays_per_image` if provided,
176+
- `(n, 1)` where n is the minimum cardinality of the mask
177+
in the batch if `mask` is provided
178+
- `(image_height, image_width)` if nothing from above is satisfied
179+
180+
`HeterogeneousRayBundle` has additional members:
181+
- camera_ids: tensor of shape (M,), where `M` is the number of unique sampled
182+
cameras. It represents unique ids of sampled cameras.
183+
- camera_counts: tensor of shape (M,), where `M` is the number of unique sampled
184+
cameras. Represents how many times each camera from `camera_ids` was sampled
185+
186+
`HeterogeneousRayBundle` is returned if `n_rays_total` is provided else `RayBundle`
187+
is returned.
160188
"""
189+
n_rays_total = n_rays_total or self._n_rays_total
190+
n_rays_per_image = n_rays_per_image or self._n_rays_per_image
191+
assert (n_rays_total is None) or (
192+
n_rays_per_image is None
193+
), "`n_rays_total` and `n_rays_per_image` cannot both be defined."
194+
if n_rays_total:
195+
(
196+
cameras,
197+
mask,
198+
camera_ids, # unique ids of sampled cameras
199+
camera_counts, # number of times unique camera id was sampled
200+
# `n_rays_per_image` is equal to the max number of times a simgle camera
201+
# was sampled. We sample all cameras at `camera_ids` `n_rays_per_image` times
202+
# and then discard the unneeded rays.
203+
# pyre-ignore[9]
204+
n_rays_per_image,
205+
) = _sample_cameras_and_masks(n_rays_total, cameras, mask)
206+
else:
207+
camera_ids = torch.range(0, len(cameras), dtype=torch.long)
208+
161209
batch_size = cameras.R.shape[0]
162210
device = cameras.device
163211

164212
# expand the (H, W, 2) grid batch_size-times to (B, H, W, 2)
165213
xy_grid = self._xy_grid.to(device).expand(batch_size, -1, -1, -1)
166214

167-
num_rays = n_rays_per_image or self._n_rays_per_image
168-
if mask is not None and num_rays is None:
215+
if mask is not None and n_rays_per_image is None:
169216
# if num rays not given, sample according to the smallest mask
170-
num_rays = num_rays or mask.sum(dim=(1, 2)).min().int().item()
217+
n_rays_per_image = (
218+
n_rays_per_image or mask.sum(dim=(1, 2)).min().int().item()
219+
)
171220

172-
if num_rays is not None:
221+
if n_rays_per_image is not None:
173222
if mask is not None:
174223
assert mask.shape == xy_grid.shape[:3]
175224
weights = mask.reshape(batch_size, -1)
@@ -181,7 +230,9 @@ def forward(
181230
weights = xy_grid.new_ones(batch_size, width * height)
182231
# pyre-fixme[6]: For 2nd param expected `int` but got `Union[bool,
183232
# float, int]`.
184-
rays_idx = _safe_multinomial(weights, num_rays)[..., None].expand(-1, -1, 2)
233+
rays_idx = _safe_multinomial(weights, n_rays_per_image)[..., None].expand(
234+
-1, -1, 2
235+
)
185236

186237
xy_grid = torch.gather(xy_grid.reshape(batch_size, -1, 2), 1, rays_idx)[
187238
:, :, None
@@ -198,7 +249,7 @@ def forward(
198249
else self._stratified_sampling
199250
)
200251

201-
return _xy_to_ray_bundle(
252+
ray_bundle = _xy_to_ray_bundle(
202253
cameras,
203254
xy_grid,
204255
min_depth,
@@ -208,6 +259,13 @@ def forward(
208259
stratified_sampling,
209260
)
210261

262+
return (
263+
# pyre-ignore[61]
264+
_pack_ray_bundle(ray_bundle, camera_ids, camera_counts)
265+
if n_rays_total
266+
else ray_bundle
267+
)
268+
211269

212270
class NDCMultinomialRaysampler(MultinomialRaysampler):
213271
"""
@@ -231,6 +289,7 @@ def __init__(
231289
min_depth: float,
232290
max_depth: float,
233291
n_rays_per_image: Optional[int] = None,
292+
n_rays_total: Optional[int] = None,
234293
unit_directions: bool = False,
235294
stratified_sampling: bool = False,
236295
) -> None:
@@ -254,6 +313,7 @@ def __init__(
254313
min_depth=min_depth,
255314
max_depth=max_depth,
256315
n_rays_per_image=n_rays_per_image,
316+
n_rays_total=n_rays_total,
257317
unit_directions=unit_directions,
258318
stratified_sampling=stratified_sampling,
259319
)
@@ -281,6 +341,7 @@ def __init__(
281341
min_depth: float,
282342
max_depth: float,
283343
*,
344+
n_rays_total: Optional[int] = None,
284345
unit_directions: bool = False,
285346
stratified_sampling: bool = False,
286347
) -> None:
@@ -294,6 +355,11 @@ def __init__(
294355
n_pts_per_ray: The number of points sampled along each ray.
295356
min_depth: The minimum depth of each ray-point.
296357
max_depth: The maximum depth of each ray-point.
358+
n_rays_total: How many rays in total to sample from the cameras provided. The result
359+
is as if `n_rays_total_training` cameras were sampled with replacement from the
360+
cameras provided and for every camera one ray was sampled. If set, this disables
361+
`n_rays_per_image` and returns the HeterogeneousRayBundleyBundle with
362+
batch_size=n_rays_total.
297363
unit_directions: whether to normalize direction vectors in ray bundle.
298364
stratified_sampling: if True, performs stratified sampling in n_pts_per_ray
299365
bins for each ray; otherwise takes n_pts_per_ray deterministic points
@@ -308,6 +374,7 @@ def __init__(
308374
self._n_pts_per_ray = n_pts_per_ray
309375
self._min_depth = min_depth
310376
self._max_depth = max_depth
377+
self._n_rays_total = n_rays_total
311378
self._unit_directions = unit_directions
312379
self._stratified_sampling = stratified_sampling
313380

@@ -317,15 +384,16 @@ def forward(
317384
*,
318385
stratified_sampling: Optional[bool] = None,
319386
**kwargs,
320-
) -> RayBundle:
387+
) -> Union[RayBundle, HeterogeneousRayBundle]:
321388
"""
322389
Args:
323390
cameras: A batch of `batch_size` cameras from which the rays are emitted.
324391
stratified_sampling: if set, overrides stratified_sampling provided
325392
in __init__.
326-
327393
Returns:
328-
A named tuple RayBundle with the following fields:
394+
A named tuple `RayBundle` or dataclass `HeterogeneousRayBundle` with the
395+
following fields:
396+
329397
origins: A tensor of shape
330398
`(batch_size, n_rays_per_image, 3)`
331399
denoting the locations of ray origins in the world coordinates.
@@ -338,7 +406,31 @@ def forward(
338406
xys: A tensor of shape
339407
`(batch_size, n_rays_per_image, 2)`
340408
containing the 2D image coordinates of each ray.
409+
If `n_rays_total` is provided `batch_size=n_rays_total`and
410+
`n_rays_per_image=1` and `HeterogeneousRayBundle` is returned else `RayBundle`
411+
is returned.
412+
413+
`HeterogeneousRayBundle` has additional members:
414+
- camera_ids: tensor of shape (M,), where `M` is the number of unique sampled
415+
cameras. It represents unique ids of sampled cameras.
416+
- camera_counts: tensor of shape (M,), where `M` is the number of unique sampled
417+
cameras. Represents how many times each camera from `camera_ids` was sampled
341418
"""
419+
assert (self._n_rays_total is None) or (
420+
self._n_rays_per_image is None
421+
), "`self.n_rays_total` and `self.n_rays_per_image` cannot both be defined."
422+
423+
if self._n_rays_total:
424+
(
425+
cameras,
426+
_,
427+
camera_ids,
428+
camera_counts,
429+
n_rays_per_image,
430+
) = _sample_cameras_and_masks(self._n_rays_total, cameras, None)
431+
else:
432+
camera_ids = torch.range(0, len(cameras), dtype=torch.long)
433+
n_rays_per_image = self._n_rays_per_image
342434

343435
batch_size = cameras.R.shape[0]
344436

@@ -349,7 +441,7 @@ def forward(
349441
rays_xy = torch.cat(
350442
[
351443
torch.rand(
352-
size=(batch_size, self._n_rays_per_image, 1),
444+
size=(batch_size, n_rays_per_image, 1),
353445
dtype=torch.float32,
354446
device=device,
355447
)
@@ -369,7 +461,7 @@ def forward(
369461
else self._stratified_sampling
370462
)
371463

372-
return _xy_to_ray_bundle(
464+
ray_bundle = _xy_to_ray_bundle(
373465
cameras,
374466
rays_xy,
375467
self._min_depth,
@@ -379,6 +471,13 @@ def forward(
379471
stratified_sampling,
380472
)
381473

474+
return (
475+
# pyre-ignore[61]
476+
_pack_ray_bundle(ray_bundle, camera_ids, camera_counts)
477+
if self._n_rays_total
478+
else ray_bundle
479+
)
480+
382481

383482
# Settings for backwards compatibility
384483
def GridRaysampler(
@@ -602,3 +701,74 @@ def _jiggle_within_stratas(bin_centers: torch.Tensor) -> torch.Tensor:
602701
# Samples in those intervals.
603702
jiggled = lower + (upper - lower) * torch.rand_like(lower)
604703
return jiggled
704+
705+
706+
def _sample_cameras_and_masks(
707+
n_samples: int, cameras: CamerasBase, mask: Optional[torch.Tensor] = None
708+
) -> Tuple[
709+
CamerasBase, Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor
710+
]:
711+
"""
712+
Samples n_rays_total cameras and masks and returns them in a form
713+
(camera_idx, count), where count represents number of times the same camera
714+
has been sampled.
715+
716+
Args:
717+
n_samples: how many camera and mask pairs to sample
718+
cameras: A batch of `batch_size` cameras from which the rays are emitted.
719+
mask: Optional. Should be of size (batch_size, image_height, image_width).
720+
Returns:
721+
tuple of a form (sampled_cameras, sampled_masks, unique_sampled_camera_ids,
722+
number_of_times_each_sampled_camera_has_been_sampled,
723+
max_number_of_times_camera_has_been_sampled,
724+
)
725+
"""
726+
sampled_ids = torch.randint(
727+
0,
728+
len(cameras),
729+
size=(n_samples,),
730+
dtype=torch.long,
731+
)
732+
unique_ids, counts = torch.unique(sampled_ids, return_counts=True)
733+
return (
734+
cameras[unique_ids],
735+
mask[unique_ids] if mask is not None else None,
736+
unique_ids,
737+
counts,
738+
torch.max(counts),
739+
)
740+
741+
742+
def _pack_ray_bundle(
743+
ray_bundle: RayBundle, camera_ids: torch.Tensor, camera_counts: torch.Tensor
744+
) -> HeterogeneousRayBundle:
745+
"""
746+
Pack the raybundle from [n_cameras, max(rays_per_camera), ...] to
747+
[total_num_rays, 1, ...]
748+
749+
Args:
750+
ray_bundle: A ray_bundle to pack
751+
camera_ids: Unique ids of cameras that were sampled
752+
camera_counts: how many of which camera to pack, each count coresponds to
753+
one 'row' of the ray_bundle and says how many rays wll be taken
754+
from it and packed.
755+
Returns:
756+
HeterogeneousRayBundle where batch_size=sum(camera_counts) and n_rays_per_image=1
757+
"""
758+
camera_counts = camera_counts.to(ray_bundle.origins.device)
759+
cumsum = torch.cumsum(camera_counts, dim=0, dtype=torch.long)
760+
first_idxs = torch.cat(
761+
(camera_counts.new_zeros((1,), dtype=torch.long), cumsum[:-1])
762+
)
763+
num_inputs = int(camera_counts.sum())
764+
765+
return HeterogeneousRayBundle(
766+
origins=padded_to_packed(ray_bundle.origins, first_idxs, num_inputs)[:, None],
767+
directions=padded_to_packed(ray_bundle.directions, first_idxs, num_inputs)[
768+
:, None
769+
],
770+
lengths=padded_to_packed(ray_bundle.lengths, first_idxs, num_inputs)[:, None],
771+
xys=padded_to_packed(ray_bundle.xys, first_idxs, num_inputs)[:, None],
772+
camera_ids=camera_ids,
773+
camera_counts=camera_counts,
774+
)

0 commit comments

Comments
 (0)