16
16
import torch
17
17
from pytorch3d .implicitron .tools .config import ReplaceableBase
18
18
from pytorch3d .ops import packed_to_padded
19
+ from pytorch3d .renderer .implicit .utils import ray_bundle_variables_to_ray_points
19
20
20
21
21
22
class EvaluationMode (Enum ):
@@ -47,6 +48,27 @@ class ImplicitronRayBundle:
47
48
camera_counts: A tensor of shape (N, ) which how many times the
48
49
coresponding camera in `camera_ids` was sampled.
49
50
`sum(camera_counts) == minibatch`, where `minibatch = origins.shape[0]`.
51
+
52
+ Attributes:
53
+ origins: A tensor of shape `(..., 3)` denoting the
54
+ origins of the sampling rays in world coords.
55
+ directions: A tensor of shape `(..., 3)` containing the direction
56
+ vectors of sampling rays in world coords. They don't have to be normalized;
57
+ they define unit vectors in the respective 1D coordinate systems; see
58
+ documentation for :func:`ray_bundle_to_ray_points` for the conversion formula.
59
+ lengths: A tensor of shape `(..., num_points_per_ray)`
60
+ containing the lengths at which the rays are sampled.
61
+ xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels
62
+ camera_ids: An optional tensor of shape (N, ) which indicates which camera
63
+ was used to sample the rays. `N` is the number of unique sampled cameras.
64
+ camera_counts: An optional tensor of shape (N, ) indicates how many times the
65
+ coresponding camera in `camera_ids` was sampled.
66
+ `sum(camera_counts)==total_number_of_rays`.
67
+ bins: An optional tensor of shape `(..., num_points_per_ray + 1)`
68
+ containing the bins at which the rays are sampled. In this case
69
+ lengths should be equal to the midpoints of bins `(..., num_points_per_ray)`.
70
+ pixel_radii_2d: An optional tensor of shape `(..., 1)`
71
+ base radii of the conical frustums.
50
72
"""
51
73
52
74
origins : torch .Tensor
@@ -55,6 +77,45 @@ class ImplicitronRayBundle:
55
77
xys : torch .Tensor
56
78
camera_ids : Optional [torch .LongTensor ] = None
57
79
camera_counts : Optional [torch .LongTensor ] = None
80
+ bins : Optional [torch .Tensor ] = None
81
+ pixel_radii_2d : Optional [torch .Tensor ] = None
82
+
83
+ @classmethod
84
+ def from_bins (
85
+ cls ,
86
+ origins : torch .Tensor ,
87
+ directions : torch .Tensor ,
88
+ bins : torch .Tensor ,
89
+ xys : torch .Tensor ,
90
+ ** kwargs ,
91
+ ) -> "ImplicitronRayBundle" :
92
+ """
93
+ Creates a new instance from bins instead of lengths.
94
+
95
+ Attributes:
96
+ origins: A tensor of shape `(..., 3)` denoting the
97
+ origins of the sampling rays in world coords.
98
+ directions: A tensor of shape `(..., 3)` containing the direction
99
+ vectors of sampling rays in world coords. They don't have to be normalized;
100
+ they define unit vectors in the respective 1D coordinate systems; see
101
+ documentation for :func:`ray_bundle_to_ray_points` for the conversion formula.
102
+ bins: A tensor of shape `(..., num_points_per_ray + 1)`
103
+ containing the bins at which the rays are sampled. In this case
104
+ lengths is equal to the midpoints of bins `(..., num_points_per_ray)`.
105
+ xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels
106
+ kwargs: Additional arguments passed to the constructor of ImplicitronRayBundle
107
+ Returns:
108
+ An instance of ImplicitronRayBundle.
109
+ """
110
+
111
+ if bins .shape [- 1 ] <= 1 :
112
+ raise ValueError (
113
+ "The last dim of bins must be at least superior or equal to 2."
114
+ )
115
+ # equivalent to: 0.5 * (bins[..., 1:] + bins[..., :-1]) but more efficient
116
+ lengths = torch .lerp (bins [..., 1 :], bins [..., :- 1 ], 0.5 )
117
+
118
+ return cls (origins , directions , lengths , xys , bins = bins , ** kwargs )
58
119
59
120
def is_packed (self ) -> bool :
60
121
"""
@@ -195,3 +256,154 @@ def forward(
195
256
instance of RendererOutput
196
257
"""
197
258
pass
259
+
260
+
261
+ def compute_3d_diagonal_covariance_gaussian (
262
+ rays_directions : torch .Tensor ,
263
+ rays_dir_variance : torch .Tensor ,
264
+ radii_variance : torch .Tensor ,
265
+ eps : float = 1e-6 ,
266
+ ) -> torch .Tensor :
267
+ """
268
+ Transform the variances (rays_dir_variance, radii_variance) of the gaussians from
269
+ the coordinate frame of the conical frustum to 3D world coordinates.
270
+
271
+ It follows the equation 16 of `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_
272
+
273
+ Args:
274
+ rays_directions: A tensor of shape `(..., 3)`
275
+ rays_dir_variance: A tensor of shape `(..., num_intervals)` representing
276
+ the variance of the conical frustum with respect to the rays direction.
277
+ radii_variance: A tensor of shape `(..., num_intervals)` representing
278
+ the variance of the conical frustum with respect to its radius.
279
+ eps: a small number to prevent division by zero.
280
+
281
+ Returns:
282
+ A tensor of shape `(..., num_intervals, 3)` containing the diagonal
283
+ of the covariance matrix.
284
+ """
285
+ d_outer_diag = torch .pow (rays_directions , 2 )
286
+ dir_mag_sq = torch .clamp (torch .sum (d_outer_diag , dim = - 1 , keepdim = True ), min = eps )
287
+
288
+ null_outer_diag = 1 - d_outer_diag / dir_mag_sq
289
+ ray_dir_cov_diag = rays_dir_variance [..., None ] * d_outer_diag [..., None , :]
290
+ xy_cov_diag = radii_variance [..., None ] * null_outer_diag [..., None , :]
291
+ return ray_dir_cov_diag + xy_cov_diag
292
+
293
+
294
+ def approximate_conical_frustum_as_gaussians (
295
+ bins : torch .Tensor , radii : torch .Tensor
296
+ ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
297
+ """
298
+ Approximates a conical frustum as two Gaussian distributions.
299
+
300
+ The Gaussian distributions are characterized by
301
+ three values:
302
+
303
+ - rays_dir_mean: mean along the rays direction
304
+ (defined as t in the parametric representation of a cone).
305
+ - rays_dir_variance: the variance of the conical frustum along the rays direction.
306
+ - radii_variance: variance of the conical frustum with respect to its radius.
307
+
308
+
309
+ The computation is stable and follows equation 7
310
+ of `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
311
+
312
+ For more information on how the mean and variances are computed
313
+ refers to the appendix of the paper.
314
+
315
+ Args:
316
+ bins: A tensor of shape `(..., num_points_per_ray + 1)`
317
+ containing the bins at which the rays are sampled.
318
+ `bin[..., t]` and `bin[..., t+1]` represent respectively
319
+ the left and right coordinates of the interval.
320
+ t0: A tensor of shape `(..., num_points_per_ray)`
321
+ containing the left coordinates of the intervals
322
+ on which the rays are sampled.
323
+ t1: A tensor of shape `(..., num_points_per_ray)`
324
+ containing the rights coordinates of the intervals
325
+ on which the rays are sampled.
326
+ radii: A tensor of shape `(..., 1)`
327
+ base radii of the conical frustums.
328
+
329
+ Returns:
330
+ rays_dir_mean: A tensor of shape `(..., num_intervals)` representing
331
+ the mean along the rays direction
332
+ (t in the parametric represention of the cone)
333
+ rays_dir_variance: A tensor of shape `(..., num_intervals)` representing
334
+ the variance of the conical frustum along the rays
335
+ (t in the parametric represention of the cone).
336
+ radii_variance: A tensor of shape `(..., num_intervals)` representing
337
+ the variance of the conical frustum with respect to its radius.
338
+ """
339
+ t_mu = torch .lerp (bins [..., 1 :], bins [..., :- 1 ], 0.5 )
340
+ t_delta = torch .diff (bins , dim = - 1 ) / 2
341
+
342
+ t_mu_pow2 = torch .pow (t_mu , 2 )
343
+ t_delta_pow2 = torch .pow (t_delta , 2 )
344
+ t_delta_pow4 = torch .pow (t_delta , 4 )
345
+
346
+ den = 3 * t_mu_pow2 + t_delta_pow2
347
+
348
+ # mean along the rays direction
349
+ rays_dir_mean = t_mu + 2 * t_mu * t_delta_pow2 / den
350
+
351
+ # Variance of the conical frustum with along the rays directions
352
+ rays_dir_variance = t_delta_pow2 / 3 - (4 / 15 ) * (
353
+ t_delta_pow4 * (12 * t_mu_pow2 - t_delta_pow2 ) / torch .pow (den , 2 )
354
+ )
355
+
356
+ # Variance of the conical frustum with respect to its radius
357
+ radii_variance = torch .pow (radii , 2 ) * (
358
+ t_mu_pow2 / 4 + (5 / 12 ) * t_delta_pow2 - 4 / 15 * (t_delta_pow4 ) / den
359
+ )
360
+ return rays_dir_mean , rays_dir_variance , radii_variance
361
+
362
+
363
+ def conical_frustum_to_gaussian (
364
+ ray_bundle : ImplicitronRayBundle ,
365
+ ) -> Tuple [torch .Tensor , torch .Tensor ]:
366
+ """
367
+ Approximate a conical frustum following a ray bundle as a Gaussian.
368
+
369
+ Args:
370
+ ray_bundle: A `RayBundle` or `HeterogeneousRayBundle` object with fields:
371
+ origins: A tensor of shape `(..., 3)`
372
+ directions: A tensor of shape `(..., 3)`
373
+ lengths: A tensor of shape `(..., num_points_per_ray)`
374
+ bins: A tensor of shape `(..., num_points_per_ray + 1)`
375
+ containing the bins at which the rays are sampled. .
376
+ pixel_radii_2d: A tensor of shape `(..., 1)`
377
+ base radii of the conical frustums.
378
+
379
+ Returns:
380
+ means: A tensor of shape `(..., num_points_per_ray - 1, 3)`
381
+ representing the means of the Gaussians
382
+ approximating the conical frustums.
383
+ diag_covariances: A tensor of shape `(...,num_points_per_ray -1, 3)`
384
+ representing the diagonal covariance matrices of our Gaussians.
385
+ """
386
+
387
+ if ray_bundle .pixel_radii_2d is None or ray_bundle .bins is None :
388
+ raise ValueError (
389
+ "RayBundle pixel_radii_2d or bins have not been provided."
390
+ " Look at pytorch3d.renderer.implicit.renderer.ray_sampler::"
391
+ "AbstractMaskRaySampler to see how to compute them. Have you forgot to set"
392
+ "`cast_ray_bundle_as_cone` to True?"
393
+ )
394
+
395
+ (
396
+ rays_dir_mean ,
397
+ rays_dir_variance ,
398
+ radii_variance ,
399
+ ) = approximate_conical_frustum_as_gaussians (
400
+ ray_bundle .bins ,
401
+ ray_bundle .pixel_radii_2d ,
402
+ )
403
+ means = ray_bundle_variables_to_ray_points (
404
+ ray_bundle .origins , ray_bundle .directions , rays_dir_mean
405
+ )
406
+ diag_covariances = compute_3d_diagonal_covariance_gaussian (
407
+ ray_bundle .directions , rays_dir_variance , radii_variance
408
+ )
409
+ return means , diag_covariances
0 commit comments