5
5
# LICENSE file in the root directory of this source tree.
6
6
7
7
import warnings
8
- from typing import Optional
8
+ from typing import Optional , Tuple , Union
9
9
10
10
import torch
11
11
from pytorch3d .common .compat import meshgrid_ij
12
+ from pytorch3d .ops import padded_to_packed
12
13
from pytorch3d .renderer .cameras import CamerasBase
13
- from pytorch3d .renderer .implicit .utils import RayBundle
14
+ from pytorch3d .renderer .implicit .utils import HeterogeneousRayBundle , RayBundle
14
15
from torch .nn import functional as F
15
16
16
17
@@ -73,6 +74,7 @@ def __init__(
73
74
min_depth : float ,
74
75
max_depth : float ,
75
76
n_rays_per_image : Optional [int ] = None ,
77
+ n_rays_total : Optional [int ] = None ,
76
78
unit_directions : bool = False ,
77
79
stratified_sampling : bool = False ,
78
80
) -> None :
@@ -88,6 +90,11 @@ def __init__(
88
90
min_depth: The minimum depth of a ray-point.
89
91
max_depth: The maximum depth of a ray-point.
90
92
n_rays_per_image: If given, this amount of rays are sampled from the grid.
93
+ n_rays_total: How many rays in total to sample from the cameras provided. The result
94
+ is as if `n_rays_total` cameras were sampled with replacement from the
95
+ cameras provided and for every camera one ray was sampled. If set, this disables
96
+ `n_rays_per_image` and returns the HeterogeneousRayBundle with
97
+ batch_size=n_rays_total.
91
98
unit_directions: whether to normalize direction vectors in ray bundle.
92
99
stratified_sampling: if True, performs stratified random sampling
93
100
along the ray; otherwise takes ray points at deterministic offsets.
@@ -97,6 +104,7 @@ def __init__(
97
104
self ._min_depth = min_depth
98
105
self ._max_depth = max_depth
99
106
self ._n_rays_per_image = n_rays_per_image
107
+ self ._n_rays_total = n_rays_total
100
108
self ._unit_directions = unit_directions
101
109
self ._stratified_sampling = stratified_sampling
102
110
@@ -125,8 +133,9 @@ def forward(
125
133
n_rays_per_image : Optional [int ] = None ,
126
134
n_pts_per_ray : Optional [int ] = None ,
127
135
stratified_sampling : Optional [bool ] = None ,
136
+ n_rays_total : Optional [int ] = None ,
128
137
** kwargs ,
129
- ) -> RayBundle :
138
+ ) -> Union [ RayBundle , HeterogeneousRayBundle ] :
130
139
"""
131
140
Args:
132
141
cameras: A batch of `batch_size` cameras from which the rays are emitted.
@@ -138,8 +147,15 @@ def forward(
138
147
n_pts_per_ray: The number of points sampled along each ray.
139
148
stratified_sampling: if set, overrides stratified_sampling provided
140
149
in __init__.
150
+ n_rays_total: How many rays in total to sample from the cameras provided. The result
151
+ is as if `n_rays_total` cameras were sampled with replacement from the
152
+ cameras provided and for every camera one ray was sampled. If set, this disables
153
+ `n_rays_per_image` and returns the HeterogeneousRayBundle with
154
+ batch_size=n_rays_total.
141
155
Returns:
142
- A named tuple RayBundle with the following fields:
156
+ A named tuple RayBundle or dataclass HeterogeneousRayBundle with the
157
+ following fields:
158
+
143
159
origins: A tensor of shape
144
160
`(batch_size, s1, s2, 3)`
145
161
denoting the locations of ray origins in the world coordinates.
@@ -153,23 +169,56 @@ def forward(
153
169
`(batch_size, s1, s2, 2)`
154
170
containing the 2D image coordinates of each ray or,
155
171
if mask is given, `(batch_size, n, 1, 2)`
156
- Here `s1, s2` refer to spatial dimensions. Unless the mask is
157
- given, they equal `(image_height, image_width)`, otherwise `(n, 1)`,
158
- where `n` is `n_rays_per_image` if provided, otherwise the minimum
159
- cardinality of the mask in the batch.
172
+ Here `s1, s2` refer to spatial dimensions.
173
+ `(s1, s2)` refer to (highest priority first):
174
+ - `(1, 1)` if `n_rays_total` is provided, (batch_size=n_rays_total)
175
+ - `(n_rays_per_image, 1)` if `n_rays_per_image` is provided,
176
+ - `(n, 1)` where n is the minimum cardinality of the mask
177
+ in the batch if `mask` is provided
178
+ - `(image_height, image_width)` if nothing from above is satisfied
179
+
180
+ `HeterogeneousRayBundle` has additional members:
181
+ - camera_ids: tensor of shape (M,), where `M` is the number of unique sampled
182
+ cameras. It represents unique ids of sampled cameras.
183
+ - camera_counts: tensor of shape (M,), where `M` is the number of unique sampled
184
+ cameras. Represents how many times each camera from `camera_ids` was sampled
185
+
186
+ `HeterogeneousRayBundle` is returned if `n_rays_total` is provided else `RayBundle`
187
+ is returned.
160
188
"""
189
+ n_rays_total = n_rays_total or self ._n_rays_total
190
+ n_rays_per_image = n_rays_per_image or self ._n_rays_per_image
191
+ assert (n_rays_total is None ) or (
192
+ n_rays_per_image is None
193
+ ), "`n_rays_total` and `n_rays_per_image` cannot both be defined."
194
+ if n_rays_total :
195
+ (
196
+ cameras ,
197
+ mask ,
198
+ camera_ids , # unique ids of sampled cameras
199
+ camera_counts , # number of times unique camera id was sampled
200
+ # `n_rays_per_image` is equal to the max number of times a single camera
201
+ # was sampled. We sample all cameras at `camera_ids` `n_rays_per_image` times
202
+ # and then discard the unneeded rays.
203
+ # pyre-ignore[9]
204
+ n_rays_per_image ,
205
+ ) = _sample_cameras_and_masks (n_rays_total , cameras , mask )
206
+ else :
207
+ camera_ids = torch .range (0 , len (cameras ), dtype = torch .long )
208
+
161
209
batch_size = cameras .R .shape [0 ]
162
210
device = cameras .device
163
211
164
212
# expand the (H, W, 2) grid batch_size-times to (B, H, W, 2)
165
213
xy_grid = self ._xy_grid .to (device ).expand (batch_size , - 1 , - 1 , - 1 )
166
214
167
- num_rays = n_rays_per_image or self ._n_rays_per_image
168
- if mask is not None and num_rays is None :
215
+ if mask is not None and n_rays_per_image is None :
169
216
# if num rays not given, sample according to the smallest mask
170
- num_rays = num_rays or mask .sum (dim = (1 , 2 )).min ().int ().item ()
217
+ n_rays_per_image = (
218
+ n_rays_per_image or mask .sum (dim = (1 , 2 )).min ().int ().item ()
219
+ )
171
220
172
- if num_rays is not None :
221
+ if n_rays_per_image is not None :
173
222
if mask is not None :
174
223
assert mask .shape == xy_grid .shape [:3 ]
175
224
weights = mask .reshape (batch_size , - 1 )
@@ -181,7 +230,9 @@ def forward(
181
230
weights = xy_grid .new_ones (batch_size , width * height )
182
231
# pyre-fixme[6]: For 2nd param expected `int` but got `Union[bool,
183
232
# float, int]`.
184
- rays_idx = _safe_multinomial (weights , num_rays )[..., None ].expand (- 1 , - 1 , 2 )
233
+ rays_idx = _safe_multinomial (weights , n_rays_per_image )[..., None ].expand (
234
+ - 1 , - 1 , 2
235
+ )
185
236
186
237
xy_grid = torch .gather (xy_grid .reshape (batch_size , - 1 , 2 ), 1 , rays_idx )[
187
238
:, :, None
@@ -198,7 +249,7 @@ def forward(
198
249
else self ._stratified_sampling
199
250
)
200
251
201
- return _xy_to_ray_bundle (
252
+ ray_bundle = _xy_to_ray_bundle (
202
253
cameras ,
203
254
xy_grid ,
204
255
min_depth ,
@@ -208,6 +259,13 @@ def forward(
208
259
stratified_sampling ,
209
260
)
210
261
262
+ return (
263
+ # pyre-ignore[61]
264
+ _pack_ray_bundle (ray_bundle , camera_ids , camera_counts )
265
+ if n_rays_total
266
+ else ray_bundle
267
+ )
268
+
211
269
212
270
class NDCMultinomialRaysampler (MultinomialRaysampler ):
213
271
"""
@@ -231,6 +289,7 @@ def __init__(
231
289
min_depth : float ,
232
290
max_depth : float ,
233
291
n_rays_per_image : Optional [int ] = None ,
292
+ n_rays_total : Optional [int ] = None ,
234
293
unit_directions : bool = False ,
235
294
stratified_sampling : bool = False ,
236
295
) -> None :
@@ -254,6 +313,7 @@ def __init__(
254
313
min_depth = min_depth ,
255
314
max_depth = max_depth ,
256
315
n_rays_per_image = n_rays_per_image ,
316
+ n_rays_total = n_rays_total ,
257
317
unit_directions = unit_directions ,
258
318
stratified_sampling = stratified_sampling ,
259
319
)
@@ -281,6 +341,7 @@ def __init__(
281
341
min_depth : float ,
282
342
max_depth : float ,
283
343
* ,
344
+ n_rays_total : Optional [int ] = None ,
284
345
unit_directions : bool = False ,
285
346
stratified_sampling : bool = False ,
286
347
) -> None :
@@ -294,6 +355,11 @@ def __init__(
294
355
n_pts_per_ray: The number of points sampled along each ray.
295
356
min_depth: The minimum depth of each ray-point.
296
357
max_depth: The maximum depth of each ray-point.
358
+ n_rays_total: How many rays in total to sample from the cameras provided. The result
359
+ is as if `n_rays_total` cameras were sampled with replacement from the
360
+ cameras provided and for every camera one ray was sampled. If set, this disables
361
+ `n_rays_per_image` and returns the HeterogeneousRayBundle with
362
+ batch_size=n_rays_total.
297
363
unit_directions: whether to normalize direction vectors in ray bundle.
298
364
stratified_sampling: if True, performs stratified sampling in n_pts_per_ray
299
365
bins for each ray; otherwise takes n_pts_per_ray deterministic points
@@ -308,6 +374,7 @@ def __init__(
308
374
self ._n_pts_per_ray = n_pts_per_ray
309
375
self ._min_depth = min_depth
310
376
self ._max_depth = max_depth
377
+ self ._n_rays_total = n_rays_total
311
378
self ._unit_directions = unit_directions
312
379
self ._stratified_sampling = stratified_sampling
313
380
@@ -317,15 +384,16 @@ def forward(
317
384
* ,
318
385
stratified_sampling : Optional [bool ] = None ,
319
386
** kwargs ,
320
- ) -> RayBundle :
387
+ ) -> Union [ RayBundle , HeterogeneousRayBundle ] :
321
388
"""
322
389
Args:
323
390
cameras: A batch of `batch_size` cameras from which the rays are emitted.
324
391
stratified_sampling: if set, overrides stratified_sampling provided
325
392
in __init__.
326
-
327
393
Returns:
328
- A named tuple RayBundle with the following fields:
394
+ A named tuple `RayBundle` or dataclass `HeterogeneousRayBundle` with the
395
+ following fields:
396
+
329
397
origins: A tensor of shape
330
398
`(batch_size, n_rays_per_image, 3)`
331
399
denoting the locations of ray origins in the world coordinates.
@@ -338,7 +406,31 @@ def forward(
338
406
xys: A tensor of shape
339
407
`(batch_size, n_rays_per_image, 2)`
340
408
containing the 2D image coordinates of each ray.
409
+ If `n_rays_total` is provided `batch_size=n_rays_total` and
410
+ `n_rays_per_image=1` and `HeterogeneousRayBundle` is returned else `RayBundle`
411
+ is returned.
412
+
413
+ `HeterogeneousRayBundle` has additional members:
414
+ - camera_ids: tensor of shape (M,), where `M` is the number of unique sampled
415
+ cameras. It represents unique ids of sampled cameras.
416
+ - camera_counts: tensor of shape (M,), where `M` is the number of unique sampled
417
+ cameras. Represents how many times each camera from `camera_ids` was sampled
341
418
"""
419
+ assert (self ._n_rays_total is None ) or (
420
+ self ._n_rays_per_image is None
421
+ ), "`self.n_rays_total` and `self.n_rays_per_image` cannot both be defined."
422
+
423
+ if self ._n_rays_total :
424
+ (
425
+ cameras ,
426
+ _ ,
427
+ camera_ids ,
428
+ camera_counts ,
429
+ n_rays_per_image ,
430
+ ) = _sample_cameras_and_masks (self ._n_rays_total , cameras , None )
431
+ else :
432
+ camera_ids = torch .range (0 , len (cameras ), dtype = torch .long )
433
+ n_rays_per_image = self ._n_rays_per_image
342
434
343
435
batch_size = cameras .R .shape [0 ]
344
436
@@ -349,7 +441,7 @@ def forward(
349
441
rays_xy = torch .cat (
350
442
[
351
443
torch .rand (
352
- size = (batch_size , self . _n_rays_per_image , 1 ),
444
+ size = (batch_size , n_rays_per_image , 1 ),
353
445
dtype = torch .float32 ,
354
446
device = device ,
355
447
)
@@ -369,7 +461,7 @@ def forward(
369
461
else self ._stratified_sampling
370
462
)
371
463
372
- return _xy_to_ray_bundle (
464
+ ray_bundle = _xy_to_ray_bundle (
373
465
cameras ,
374
466
rays_xy ,
375
467
self ._min_depth ,
@@ -379,6 +471,13 @@ def forward(
379
471
stratified_sampling ,
380
472
)
381
473
474
+ return (
475
+ # pyre-ignore[61]
476
+ _pack_ray_bundle (ray_bundle , camera_ids , camera_counts )
477
+ if self ._n_rays_total
478
+ else ray_bundle
479
+ )
480
+
382
481
383
482
# Settings for backwards compatibility
384
483
def GridRaysampler (
@@ -602,3 +701,74 @@ def _jiggle_within_stratas(bin_centers: torch.Tensor) -> torch.Tensor:
602
701
# Samples in those intervals.
603
702
jiggled = lower + (upper - lower ) * torch .rand_like (lower )
604
703
return jiggled
704
+
705
+
706
+ def _sample_cameras_and_masks (
707
+ n_samples : int , cameras : CamerasBase , mask : Optional [torch .Tensor ] = None
708
+ ) -> Tuple [
709
+ CamerasBase , Optional [torch .Tensor ], torch .Tensor , torch .Tensor , torch .Tensor
710
+ ]:
711
+ """
712
+ Samples n_rays_total cameras and masks and returns them in a form
713
+ (camera_idx, count), where count represents number of times the same camera
714
+ has been sampled.
715
+
716
+ Args:
717
+ n_samples: how many camera and mask pairs to sample
718
+ cameras: A batch of `batch_size` cameras from which the rays are emitted.
719
+ mask: Optional. Should be of size (batch_size, image_height, image_width).
720
+ Returns:
721
+ tuple of a form (sampled_cameras, sampled_masks, unique_sampled_camera_ids,
722
+ number_of_times_each_sampled_camera_has_been_sampled,
723
+ max_number_of_times_camera_has_been_sampled,
724
+ )
725
+ """
726
+ sampled_ids = torch .randint (
727
+ 0 ,
728
+ len (cameras ),
729
+ size = (n_samples ,),
730
+ dtype = torch .long ,
731
+ )
732
+ unique_ids , counts = torch .unique (sampled_ids , return_counts = True )
733
+ return (
734
+ cameras [unique_ids ],
735
+ mask [unique_ids ] if mask is not None else None ,
736
+ unique_ids ,
737
+ counts ,
738
+ torch .max (counts ),
739
+ )
740
+
741
+
742
def _pack_ray_bundle(
    ray_bundle: RayBundle, camera_ids: torch.Tensor, camera_counts: torch.Tensor
) -> HeterogeneousRayBundle:
    """
    Repack a padded ray bundle of layout [n_cameras, max(rays_per_camera), ...]
    into the layout [total_num_rays, 1, ...].

    Args:
        ray_bundle: The ray bundle to pack.
        camera_ids: Unique ids of the cameras that were sampled; passed through
            to the result unchanged.
        camera_counts: How many rays to keep per camera. Each count corresponds
            to one 'row' of `ray_bundle` and says how many rays will be taken
            from that row and packed.

    Returns:
        HeterogeneousRayBundle where batch_size=sum(camera_counts) and
        n_rays_per_image=1. Its `camera_counts` is the input moved to the device
        of `ray_bundle.origins`.
    """
    counts = camera_counts.to(ray_bundle.origins.device)
    row_ends = torch.cumsum(counts, dim=0, dtype=torch.long)
    # Exclusive prefix sum: offset of each camera's first ray in the packed
    # output.
    row_starts = torch.cat(
        (counts.new_zeros((1,), dtype=torch.long), row_ends[:-1])
    )
    total_rays = int(counts.sum())

    def _packed(padded: torch.Tensor) -> torch.Tensor:
        # Pack one field, then re-insert a singleton axis after the batch axis.
        return padded_to_packed(padded, row_starts, total_rays)[:, None]

    return HeterogeneousRayBundle(
        origins=_packed(ray_bundle.origins),
        directions=_packed(ray_bundle.directions),
        lengths=_packed(ray_bundle.lengths),
        xys=_packed(ray_bundle.xys),
        camera_ids=camera_ids,
        camera_counts=counts,
    )
0 commit comments