Skip to content

Commit 72c3a0e

Browse files
Darijan Gudeljfacebook-github-bot
Darijan Gudelj
authored andcommitted
raybundle input to ImplicitFunctions -> api unification
Summary: Currently some implicit functions in implicitron take a raybundle, others take ray_points_world. raybundle is what they really need. However, the raybundle is going to become a bit more flexible later, as it will contain different numbers of rays for each camera. Reviewed By: bottler Differential Revision: D39173751 fbshipit-source-id: ebc038e426d22e831e67a18ba64655d8a61e1eb9
1 parent 70dc9c4 commit 72c3a0e

File tree

9 files changed

+60
-19
lines changed

9 files changed

+60
-19
lines changed

pytorch3d/implicitron/models/implicit_function/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __init__(self):
1919
@abstractmethod
2020
def forward(
2121
self,
22+
*,
2223
ray_bundle: RayBundle,
2324
fun_viewpool=None,
2425
camera: Optional[CamerasBase] = None,

pytorch3d/implicitron/models/implicit_function/idr_feature_field.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
# implicit_differentiable_renderer.py
44
# Copyright (c) 2020 Lior Yariv
55
import math
6-
from typing import Tuple
6+
from typing import Optional, Tuple
77

88
import torch
99
from pytorch3d.implicitron.tools.config import registry
10-
from pytorch3d.renderer.implicit import HarmonicEmbedding
10+
from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle
1111
from torch import nn
1212

1313
from .base import ImplicitFunctionBase
14+
from .utils import get_rays_points_world
1415

1516

1617
@registry.register
@@ -125,14 +126,16 @@ def __post_init__(self):
125126
# inconsistently.
126127
def forward(
127128
self,
128-
# ray_bundle: RayBundle,
129-
rays_points_world: torch.Tensor, # TODO: unify the APIs
129+
*,
130+
ray_bundle: Optional[RayBundle] = None,
131+
rays_points_world: Optional[torch.Tensor] = None,
130132
fun_viewpool=None,
131133
global_code=None,
134+
**kwargs,
132135
):
133136
# this field only uses point locations
134-
# rays_points_world = ray_bundle_to_ray_points(ray_bundle)
135137
# rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]
138+
rays_points_world = get_rays_points_world(ray_bundle, rays_points_world)
136139

137140
if rays_points_world.numel() == 0 or (
138141
self.embed_fn is None and fun_viewpool is None and global_code is None
@@ -179,4 +182,4 @@ def forward(
179182
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
180183
x = self.softplus(x)
181184

182-
return x # TODO: unify the APIs
185+
return x

pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def allows_multiple_passes() -> bool:
129129

130130
def forward(
131131
self,
132+
*,
132133
ray_bundle: RayBundle,
133134
fun_viewpool=None,
134135
camera: Optional[CamerasBase] = None,

pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,7 @@ def raymarch_function_tweak_args(cls, type, args: DictConfig) -> None:
349349

350350
def forward(
351351
self,
352+
*,
352353
ray_bundle: RayBundle,
353354
fun_viewpool=None,
354355
camera: Optional[CamerasBase] = None,
@@ -408,6 +409,7 @@ def hypernet_tweak_args(cls, type, args: DictConfig) -> None:
408409

409410
def forward(
410411
self,
412+
*,
411413
ray_bundle: RayBundle,
412414
fun_viewpool=None,
413415
camera: Optional[CamerasBase] = None,

pytorch3d/implicitron/models/implicit_function/utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010

1111
import torch.nn.functional as F
1212
from pytorch3d.common.compat import prod
13+
from pytorch3d.renderer import ray_bundle_to_ray_points
1314
from pytorch3d.renderer.cameras import CamerasBase
15+
from pytorch3d.renderer.implicit import RayBundle
1416

1517

1618
def broadcast_global_code(embeds: torch.Tensor, global_code: torch.Tensor):
@@ -185,3 +187,31 @@ def interpolate_volume(
185187
**kwargs,
186188
)
187189
return out[:, :, :, 0, 0].permute(0, 2, 1)
190+
191+
192+
def get_rays_points_world(
193+
ray_bundle: Optional[RayBundle] = None,
194+
rays_points_world: Optional[torch.Tensor] = None,
195+
) -> torch.Tensor:
196+
"""
197+
Converts the ray_bundle to rays_points_world if rays_points_world is not defined
198+
and raises error if both are defined.
199+
200+
Args:
201+
ray_bundle: A RayBundle object or None
202+
rays_points_world: A torch.Tensor representing ray points converted to
203+
world coordinates
204+
Returns:
205+
A torch.Tensor representing ray points converted to world coordinates
206+
of shape [minibatch x ... x pts_per_ray x 3].
207+
"""
208+
if rays_points_world is not None and ray_bundle is not None:
209+
raise ValueError(
210+
"Cannot define both rays_points_world and ray_bundle,"
211+
+ " one has to be None."
212+
)
213+
if rays_points_world is not None:
214+
return rays_points_world
215+
if ray_bundle is not None:
216+
return ray_bundle_to_ray_points(ray_bundle)
217+
raise ValueError("ray_bundle and rays_points_world cannot both be None")

pytorch3d/implicitron/models/renderer/lstm_renderer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def forward(
118118

119119
# eval the raymarching function
120120
raymarch_features, _ = implicit_function(
121-
ray_bundle_t,
121+
ray_bundle=ray_bundle_t,
122122
raymarch_features=None,
123123
)
124124
if self.verbose:

pytorch3d/implicitron/models/renderer/multipass_ea.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def _run_raymarcher(
148148
)
149149

150150
output = self.raymarcher(
151-
*implicit_functions[0](ray_bundle),
151+
*implicit_functions[0](ray_bundle=ray_bundle),
152152
ray_lengths=ray_bundle.lengths,
153153
density_noise_std=density_noise_std,
154154
)

pytorch3d/implicitron/models/renderer/sdf_renderer.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def forward(
101101
object_mask = object_mask.bool()
102102

103103
implicit_function = implicit_functions[0]
104-
implicit_function_gradient = functools.partial(gradient, implicit_function)
104+
implicit_function_gradient = functools.partial(_gradient, implicit_function)
105105

106106
# object_mask: silhouette of the object
107107
batch_size, *spatial_size, _ = ray_bundle.lengths.shape
@@ -113,7 +113,7 @@ def forward(
113113

114114
with torch.no_grad(), evaluating(implicit_function):
115115
points, network_object_mask, dists = self.ray_tracer(
116-
sdf=lambda x: implicit_function(x)[
116+
sdf=lambda x: implicit_function(rays_points_world=x)[
117117
:, 0
118118
], # TODO: get rid of this wrapper
119119
cam_loc=cam_loc,
@@ -125,7 +125,7 @@ def forward(
125125
depth = dists.reshape(batch_size, num_pixels, 1)
126126
points = (cam_loc + depth * ray_dirs).reshape(-1, 3)
127127

128-
sdf_output = implicit_function(points)[:, 0:1]
128+
sdf_output = implicit_function(rays_points_world=points)[:, 0:1]
129129
# NOTE most of the intermediate variables are flattened for
130130
# no apparent reason (here and in the ray tracer)
131131
ray_dirs = ray_dirs.reshape(-1, 3)
@@ -157,7 +157,7 @@ def forward(
157157

158158
points_all = torch.cat([surface_points, eikonal_points], dim=0)
159159

160-
output = implicit_function(surface_points)
160+
output = implicit_function(rays_points_world=surface_points)
161161
surface_sdf_values = output[
162162
:N, 0:1
163163
].detach() # how is it different from sdf_output?
@@ -181,7 +181,9 @@ def forward(
181181
grad_theta = None
182182

183183
empty_render = differentiable_surface_points.shape[0] == 0
184-
features = implicit_function(differentiable_surface_points)[None, :, 1:]
184+
features = implicit_function(rays_points_world=differentiable_surface_points)[
185+
None, :, 1:
186+
]
185187
normals_full = features.new_zeros(
186188
batch_size, *spatial_size, 3, requires_grad=empty_render
187189
)
@@ -260,13 +262,13 @@ def _sample_network(
260262

261263

262264
@torch.enable_grad()
263-
def gradient(module, x):
264-
x.requires_grad_(True)
265-
y = module.forward(x)[:, :1]
265+
def _gradient(module, rays_points_world):
266+
rays_points_world.requires_grad_(True)
267+
y = module.forward(rays_points_world=rays_points_world)[:, :1]
266268
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
267269
gradients = torch.autograd.grad(
268270
outputs=y,
269-
inputs=x,
271+
inputs=rays_points_world,
270272
grad_outputs=d_output,
271273
create_graph=True,
272274
retain_graph=True,

tests/implicitron/test_srn.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def test_srn_implicit_function(self):
4444
implicit_function = SRNImplicitFunction()
4545
device = torch.device("cpu")
4646
bundle = self._get_bundle(device=device)
47-
rays_densities, rays_colors = implicit_function(bundle)
47+
rays_densities, rays_colors = implicit_function(ray_bundle=bundle)
4848
out_features = implicit_function.raymarch_function.out_features
4949
self.assertEqual(
5050
rays_densities.shape,
@@ -62,7 +62,9 @@ def test_srn_hypernet_implicit_function(self):
6262
implicit_function.to(device)
6363
global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device)
6464
bundle = self._get_bundle(device=device)
65-
rays_densities, rays_colors = implicit_function(bundle, global_code=global_code)
65+
rays_densities, rays_colors = implicit_function(
66+
ray_bundle=bundle, global_code=global_code
67+
)
6668
out_features = implicit_function.hypernet.out_features
6769
self.assertEqual(
6870
rays_densities.shape,

0 commit comments

Comments
 (0)