
Commit c8f3d6b

bottler authored and facebook-github-bot committed
Fix Transform3d.stack of compositions
Summary: Add a test for Transform3d.stack, and make it work with composed transformations. Fixes #1072.

Reviewed By: patricklabatut

Differential Revision: D34211920

fbshipit-source-id: bfbd0895494ca2ad3d08a61bc82ba23637e168cc
1 parent 2a1de3b commit c8f3d6b
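
A minimal sketch of the scenario this commit addresses (the transform values below are made up for illustration): before the change, Transform3d.stack read each transform's raw _matrix, so a transform built by chaining, whose pending factors live in an internal list, lost its composition when stacked.

import torch
from pytorch3d.transforms import Scale, Transform3d

# Build a transform by chaining: its factors are stored lazily and only
# flattened by get_matrix().
composed = Transform3d().rotate_axis_angle(30.0, axis="Z").translate(0.0, 0.0, 1.0)
single = Scale(2.0)

# With this fix, stack() batches the *composed* matrices.
batched = single.stack(composed)
print(batched.get_matrix().shape)  # torch.Size([2, 4, 4])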

File tree

pytorch3d/renderer/cameras.py
pytorch3d/transforms/transform3d.py
tests/test_transforms.py

3 files changed: +64 additions, -26 deletions


pytorch3d/renderer/cameras.py

Lines changed: 1 addition & 1 deletion
@@ -1649,7 +1649,7 @@ def look_at_view_transform(
     elev=0.0,
     azim=0.0,
     degrees: bool = True,
-    eye: Optional[Sequence] = None,
+    eye: Optional[Union[Sequence, torch.Tensor]] = None,
     at=((0, 0, 0),),  # (1, 3)
     up=((0, 1, 0),),  # (1, 3)
     device: Device = "cpu",
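
The annotation change above widens eye to also accept a torch.Tensor of camera positions rather than only a Python sequence; a minimal usage sketch (the positions are arbitrary example values):

import torch
from pytorch3d.renderer import look_at_view_transform

# eye as an (N, 3) tensor of camera positions; at/up keep their defaults.
eye = torch.tensor([[0.0, 0.0, 3.0], [2.0, 2.0, 2.0]])
R, T = look_at_view_transform(eye=eye)  # R: (2, 3, 3), T: (2, 3)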

pytorch3d/transforms/transform3d.py

Lines changed: 37 additions & 25 deletions
@@ -196,10 +196,10 @@ def __getitem__(
             index = [index]
         return self.__class__(matrix=self.get_matrix()[index])

-    def compose(self, *others):
+    def compose(self, *others: "Transform3d") -> "Transform3d":
         """
-        Return a new Transform3d with the transforms to compose stored as
-        an internal list.
+        Return a new Transform3d representing the composition of self with the
+        given other transforms, which will be stored as an internal list.

         Args:
             *others: Any number of Transform3d objects
@@ -216,7 +216,7 @@ def compose(self, *others):
         out._transforms = self._transforms + list(others)
         return out

-    def get_matrix(self):
+    def get_matrix(self) -> torch.Tensor:
         """
         Return a matrix which is the result of composing this transform
         with others stored in self.transforms. Where necessary transforms
@@ -240,13 +240,13 @@ def get_matrix(self):
                 composed_matrix = _broadcast_bmm(composed_matrix, other_matrix)
         return composed_matrix

-    def _get_matrix_inverse(self):
+    def _get_matrix_inverse(self) -> torch.Tensor:
         """
         Return the inverse of self._matrix.
         """
         return torch.inverse(self._matrix)

-    def inverse(self, invert_composed: bool = False):
+    def inverse(self, invert_composed: bool = False) -> "Transform3d":
         """
         Returns a new Transform3d object that represents an inverse of the
         current transformation.
@@ -295,14 +295,24 @@ def inverse(self, invert_composed: bool = False):

         return tinv

-    def stack(self, *others):
+    def stack(self, *others: "Transform3d") -> "Transform3d":
+        """
+        Return a new batched Transform3d representing the batch elements from
+        self and all the given other transforms all batched together.
+
+        Args:
+            *others: Any number of Transform3d objects
+
+        Returns:
+            A new Transform3d.
+        """
         transforms = [self] + list(others)
-        matrix = torch.cat([t._matrix for t in transforms], dim=0)
+        matrix = torch.cat([t.get_matrix() for t in transforms], dim=0)
         out = Transform3d(dtype=self.dtype, device=self.device)
         out._matrix = matrix
         return out

-    def transform_points(self, points, eps: Optional[float] = None):
+    def transform_points(self, points, eps: Optional[float] = None) -> torch.Tensor:
         """
         Use this transform to transform a set of 3D points. Assumes row major
         ordering of the input points.
@@ -347,7 +357,7 @@ def transform_points(self, points, eps: Optional[float] = None):

         return points_out

-    def transform_normals(self, normals):
+    def transform_normals(self, normals) -> torch.Tensor:
         """
         Use this transform to transform a set of normal vectors.

@@ -379,19 +389,19 @@ def transform_normals(self, normals):

         return normals_out

-    def translate(self, *args, **kwargs):
+    def translate(self, *args, **kwargs) -> "Transform3d":
         return self.compose(Translate(device=self.device, *args, **kwargs))

-    def scale(self, *args, **kwargs):
+    def scale(self, *args, **kwargs) -> "Transform3d":
         return self.compose(Scale(device=self.device, *args, **kwargs))

-    def rotate(self, *args, **kwargs):
+    def rotate(self, *args, **kwargs) -> "Transform3d":
         return self.compose(Rotate(device=self.device, *args, **kwargs))

-    def rotate_axis_angle(self, *args, **kwargs):
+    def rotate_axis_angle(self, *args, **kwargs) -> "Transform3d":
         return self.compose(RotateAxisAngle(device=self.device, *args, **kwargs))

-    def clone(self):
+    def clone(self) -> "Transform3d":
         """
         Deep copy of Transforms object. All internal tensors are cloned
         individually.
@@ -411,7 +421,7 @@ def to(
         device: Device,
         copy: bool = False,
         dtype: Optional[torch.dtype] = None,
-    ):
+    ) -> "Transform3d":
         """
         Match functionality of torch.Tensor.to()
         If copy = True or the self Tensor is on a different device, the
@@ -448,10 +458,10 @@ def to(
             ]
         return other

-    def cpu(self):
+    def cpu(self) -> "Transform3d":
         return self.to("cpu")

-    def cuda(self):
+    def cuda(self) -> "Transform3d":
         return self.to("cuda")


@@ -486,7 +496,7 @@ def __init__(
         mat[:, 3, :3] = xyz
         self._matrix = mat

-    def _get_matrix_inverse(self):
+    def _get_matrix_inverse(self) -> torch.Tensor:
         """
         Return the inverse of self._matrix.
         """
@@ -533,7 +543,7 @@ def __init__(
         mat[:, 2, 2] = xyz[:, 2]
         self._matrix = mat

-    def _get_matrix_inverse(self):
+    def _get_matrix_inverse(self) -> torch.Tensor:
         """
         Return the inverse of self._matrix.
         """
@@ -575,7 +585,7 @@ def __init__(
         mat[:, :3, :3] = R
         self._matrix = mat

-    def _get_matrix_inverse(self):
+    def _get_matrix_inverse(self) -> torch.Tensor:
         """
         Return the inverse of self._matrix.
         """
@@ -622,7 +632,7 @@ def __init__(
         super().__init__(device=angle.device, R=R)


-def _handle_coord(c, dtype: torch.dtype, device: torch.device):
+def _handle_coord(c, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
     """
     Helper function for _handle_input.

@@ -649,7 +659,7 @@ def _handle_input(
     device: Optional[Device],
     name: str,
     allow_singleton: bool = False,
-):
+) -> torch.Tensor:
     """
     Helper function to handle parsing logic for building transforms. The output
     is always a tensor of shape (N, 3), but there are several types of allowed
@@ -707,7 +717,9 @@ def _handle_input(
     return xyz


-def _handle_angle_input(x, dtype: torch.dtype, device: Optional[Device], name: str):
+def _handle_angle_input(
+    x, dtype: torch.dtype, device: Optional[Device], name: str
+) -> torch.Tensor:
     """
     Helper function for building a rotation function using angles.
     The output is always of shape (N,).
@@ -725,7 +737,7 @@ def _handle_angle_input(x, dtype: torch.dtype, device: Optional[Device], name: str):
     return _handle_coord(x, dtype, device_)


-def _broadcast_bmm(a, b):
+def _broadcast_bmm(a, b) -> torch.Tensor:
     """
     Batch multiply two matrices and broadcast if necessary.
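
The substantive change in this file is the one-line switch inside stack: a Transform3d built via compose (or the fluent rotate / translate / scale helpers) keeps its pending factors in _transforms and only flattens them in get_matrix(), so reading t._matrix directly dropped those factors. A quick check of the difference on a composed transform (a sketch that peeks at the private _matrix attribute purely for illustration):

import torch
from pytorch3d.transforms import Transform3d

t = Transform3d().scale(2.0).translate(1.0, 0.0, 0.0)
# _matrix is still the identity placeholder; the scale and translate factors
# sit in t._transforms until get_matrix() composes them.
print(torch.equal(t._matrix, torch.eye(4)[None]))       # True
print(torch.equal(t.get_matrix(), torch.eye(4)[None]))  # False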

tests/test_transforms.py

Lines changed: 26 additions & 0 deletions
@@ -10,6 +10,7 @@

 import torch
 from common_testing import TestCaseMixin
+from pytorch3d.transforms import random_rotations
 from pytorch3d.transforms.so3 import so3_exp_map
 from pytorch3d.transforms.transform3d import (
     Rotate,
@@ -21,6 +22,9 @@


 class TestTransform(TestCaseMixin, unittest.TestCase):
+    def setUp(self) -> None:
+        torch.manual_seed(42)
+
     def test_to(self):
         tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]]))
         R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
@@ -406,6 +410,28 @@ def test_get_item(self, batch_size=5):
         with self.assertRaises(IndexError):
             t3d_selected = t3d[invalid_index]

+    def test_stack(self):
+        rotations = random_rotations(3)
+        transform3 = Transform3d().rotate(rotations).translate(torch.full((3, 3), 0.3))
+        transform1 = Scale(37)
+        transform4 = transform1.stack(transform3)
+        self.assertEqual(len(transform1), 1)
+        self.assertEqual(len(transform3), 3)
+        self.assertEqual(len(transform4), 4)
+        self.assertClose(
+            transform4.get_matrix(),
+            torch.cat([transform1.get_matrix(), transform3.get_matrix()]),
+        )
+        points = torch.rand(4, 5, 3)
+        new_points_expect = torch.cat(
+            [
+                transform1.transform_points(points[:1]),
+                transform3.transform_points(points[1:]),
+            ]
+        )
+        new_points = transform4.transform_points(points)
+        self.assertClose(new_points, new_points_expect)
+

 class TestTranslate(unittest.TestCase):
     def test_python_scalar(self):
