@@ -196,10 +196,10 @@ def __getitem__(
196
196
index = [index ]
197
197
return self .__class__ (matrix = self .get_matrix ()[index ])
198
198
199
- def compose (self , * others ) :
199
+ def compose (self , * others : "Transform3d" ) -> "Transform3d" :
200
200
"""
201
- Return a new Transform3d with the transforms to compose stored as
202
- an internal list.
201
+ Return a new Transform3d representing the composition of self with the
202
+ given other transforms, which will be stored as an internal list.
203
203
204
204
Args:
205
205
*others: Any number of Transform3d objects
@@ -216,7 +216,7 @@ def compose(self, *others):
216
216
out ._transforms = self ._transforms + list (others )
217
217
return out
218
218
219
- def get_matrix (self ):
219
+ def get_matrix (self ) -> torch . Tensor :
220
220
"""
221
221
Return a matrix which is the result of composing this transform
222
222
with others stored in self.transforms. Where necessary transforms
@@ -240,13 +240,13 @@ def get_matrix(self):
240
240
composed_matrix = _broadcast_bmm (composed_matrix , other_matrix )
241
241
return composed_matrix
242
242
243
- def _get_matrix_inverse (self ):
243
+ def _get_matrix_inverse (self ) -> torch . Tensor :
244
244
"""
245
245
Return the inverse of self._matrix.
246
246
"""
247
247
return torch .inverse (self ._matrix )
248
248
249
- def inverse (self , invert_composed : bool = False ):
249
+ def inverse (self , invert_composed : bool = False ) -> "Transform3d" :
250
250
"""
251
251
Returns a new Transform3d object that represents an inverse of the
252
252
current transformation.
@@ -295,14 +295,24 @@ def inverse(self, invert_composed: bool = False):
295
295
296
296
return tinv
297
297
298
- def stack (self , * others ):
298
+ def stack (self , * others : "Transform3d" ) -> "Transform3d" :
299
+ """
300
+ Return a new batched Transform3d representing the batch elements from
301
+ self and all the given other transforms all batched together.
302
+
303
+ Args:
304
+ *others: Any number of Transform3d objects
305
+
306
+ Returns:
307
+ A new Transform3d.
308
+ """
299
309
transforms = [self ] + list (others )
300
- matrix = torch .cat ([t ._matrix for t in transforms ], dim = 0 )
310
+ matrix = torch .cat ([t .get_matrix () for t in transforms ], dim = 0 )
301
311
out = Transform3d (dtype = self .dtype , device = self .device )
302
312
out ._matrix = matrix
303
313
return out
304
314
305
- def transform_points (self , points , eps : Optional [float ] = None ):
315
+ def transform_points (self , points , eps : Optional [float ] = None ) -> torch . Tensor :
306
316
"""
307
317
Use this transform to transform a set of 3D points. Assumes row major
308
318
ordering of the input points.
@@ -347,7 +357,7 @@ def transform_points(self, points, eps: Optional[float] = None):
347
357
348
358
return points_out
349
359
350
- def transform_normals (self , normals ):
360
+ def transform_normals (self , normals ) -> torch . Tensor :
351
361
"""
352
362
Use this transform to transform a set of normal vectors.
353
363
@@ -379,19 +389,19 @@ def transform_normals(self, normals):
379
389
380
390
return normals_out
381
391
382
- def translate (self , * args , ** kwargs ):
392
+ def translate (self , * args , ** kwargs ) -> "Transform3d" :
383
393
return self .compose (Translate (device = self .device , * args , ** kwargs ))
384
394
385
- def scale (self , * args , ** kwargs ):
395
+ def scale (self , * args , ** kwargs ) -> "Transform3d" :
386
396
return self .compose (Scale (device = self .device , * args , ** kwargs ))
387
397
388
- def rotate (self , * args , ** kwargs ):
398
+ def rotate (self , * args , ** kwargs ) -> "Transform3d" :
389
399
return self .compose (Rotate (device = self .device , * args , ** kwargs ))
390
400
391
- def rotate_axis_angle (self , * args , ** kwargs ):
401
+ def rotate_axis_angle (self , * args , ** kwargs ) -> "Transform3d" :
392
402
return self .compose (RotateAxisAngle (device = self .device , * args , ** kwargs ))
393
403
394
- def clone (self ):
404
+ def clone (self ) -> "Transform3d" :
395
405
"""
396
406
Deep copy of Transforms object. All internal tensors are cloned
397
407
individually.
@@ -411,7 +421,7 @@ def to(
411
421
device : Device ,
412
422
copy : bool = False ,
413
423
dtype : Optional [torch .dtype ] = None ,
414
- ):
424
+ ) -> "Transform3d" :
415
425
"""
416
426
Match functionality of torch.Tensor.to()
417
427
If copy = True or the self Tensor is on a different device, the
@@ -448,10 +458,10 @@ def to(
448
458
]
449
459
return other
450
460
451
- def cpu (self ):
461
+ def cpu (self ) -> "Transform3d" :
452
462
return self .to ("cpu" )
453
463
454
- def cuda (self ):
464
+ def cuda (self ) -> "Transform3d" :
455
465
return self .to ("cuda" )
456
466
457
467
@@ -486,7 +496,7 @@ def __init__(
486
496
mat [:, 3 , :3 ] = xyz
487
497
self ._matrix = mat
488
498
489
- def _get_matrix_inverse (self ):
499
+ def _get_matrix_inverse (self ) -> torch . Tensor :
490
500
"""
491
501
Return the inverse of self._matrix.
492
502
"""
@@ -533,7 +543,7 @@ def __init__(
533
543
mat [:, 2 , 2 ] = xyz [:, 2 ]
534
544
self ._matrix = mat
535
545
536
- def _get_matrix_inverse (self ):
546
+ def _get_matrix_inverse (self ) -> torch . Tensor :
537
547
"""
538
548
Return the inverse of self._matrix.
539
549
"""
@@ -575,7 +585,7 @@ def __init__(
575
585
mat [:, :3 , :3 ] = R
576
586
self ._matrix = mat
577
587
578
- def _get_matrix_inverse (self ):
588
+ def _get_matrix_inverse (self ) -> torch . Tensor :
579
589
"""
580
590
Return the inverse of self._matrix.
581
591
"""
@@ -622,7 +632,7 @@ def __init__(
622
632
super ().__init__ (device = angle .device , R = R )
623
633
624
634
625
- def _handle_coord (c , dtype : torch .dtype , device : torch .device ):
635
+ def _handle_coord (c , dtype : torch .dtype , device : torch .device ) -> torch . Tensor :
626
636
"""
627
637
Helper function for _handle_input.
628
638
@@ -649,7 +659,7 @@ def _handle_input(
649
659
device : Optional [Device ],
650
660
name : str ,
651
661
allow_singleton : bool = False ,
652
- ):
662
+ ) -> torch . Tensor :
653
663
"""
654
664
Helper function to handle parsing logic for building transforms. The output
655
665
is always a tensor of shape (N, 3), but there are several types of allowed
@@ -707,7 +717,9 @@ def _handle_input(
707
717
return xyz
708
718
709
719
710
- def _handle_angle_input (x , dtype : torch .dtype , device : Optional [Device ], name : str ):
720
+ def _handle_angle_input (
721
+ x , dtype : torch .dtype , device : Optional [Device ], name : str
722
+ ) -> torch .Tensor :
711
723
"""
712
724
Helper function for building a rotation function using angles.
713
725
The output is always of shape (N,).
@@ -725,7 +737,7 @@ def _handle_angle_input(x, dtype: torch.dtype, device: Optional[Device], name: s
725
737
return _handle_coord (x , dtype , device_ )
726
738
727
739
728
- def _broadcast_bmm (a , b ):
740
+ def _broadcast_bmm (a , b ) -> torch . Tensor :
729
741
"""
730
742
Batch multiply two matrices and broadcast if necessary.
731
743
0 commit comments