Skip to content

Commit 9e2bc3a

Browse files
bottlerfacebook-github-bot
authored andcommitted
ambient lights batching #1043
Summary: convert_to_tensors_and_broadcast had a special case for a single input, which is not used anywhere except fails to do the right thing if a TensorProperties has only one kwarg. At the moment AmbientLights may be the only way to hit the problem. Fix by removing the special case. Fixes #1043 Reviewed By: nikhilaravi Differential Revision: D33638345 fbshipit-source-id: 7a6695f44242e650504320f73b6da74254d49ac7
1 parent fddd6a7 commit 9e2bc3a

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

pytorch3d/renderer/utils.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,4 @@ def convert_to_tensors_and_broadcast(
349349
expand_sizes = (N,) + (-1,) * len(c.shape[1:])
350350
args_Nd.append(c.expand(*expand_sizes))
351351

352-
if len(args) == 1:
353-
args_Nd = args_Nd[0] # Return the first element
354-
355352
return args_Nd

tests/test_lighting.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numpy as np
1010
import torch
1111
from common_testing import TestCaseMixin
12-
from pytorch3d.renderer.lighting import DirectionalLights, PointLights
12+
from pytorch3d.renderer.lighting import AmbientLights, DirectionalLights, PointLights
1313
from pytorch3d.transforms import RotateAxisAngle
1414

1515

@@ -121,6 +121,17 @@ def test_initialize_lights_dimensions_fail(self):
121121
with self.assertRaises(ValueError):
122122
PointLights(location=torch.randn(10, 4))
123123

124+
def test_initialize_ambient(self):
125+
N = 13
126+
color = 0.8 * torch.ones((N, 3))
127+
lights = AmbientLights(ambient_color=color)
128+
self.assertEqual(len(lights), N)
129+
self.assertClose(lights.ambient_color, color)
130+
131+
lights = AmbientLights(ambient_color=color[:1])
132+
self.assertEqual(len(lights), 1)
133+
self.assertClose(lights.ambient_color, color[:1])
134+
124135

125136
class TestDiffuseLighting(TestCaseMixin, unittest.TestCase):
126137
def test_diffuse_directional_lights(self):

0 commit comments

Comments
 (0)