Skip to content

Commit 74bbd6f

Browse files
shapovalovfacebook-github-bot
authored andcommitted
Fix returning a proper rotation in levelling; supporting batches and default centroid
Summary: `get_rotation_to_best_fit_xy` is useful to expose externally, however there was a bug (which we probably did not care about for our use case): it could return a rotation matrix with det(R) == −1. The diff fixes that, and also makes centroid optional (it can be computed from points). Reviewed By: bottler Differential Revision: D39926791 fbshipit-source-id: 5120c7892815b829f3ddcc23e93d4a5ec0ca0013
1 parent de98c9c commit 74bbd6f

File tree

2 files changed

+45
-10
lines changed

2 files changed

+45
-10
lines changed

pytorch3d/implicitron/tools/circle_fitting.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,30 @@
1212
import torch
1313

1414

15-
def _get_rotation_to_best_fit_xy(
16-
points: torch.Tensor, centroid: torch.Tensor
15+
def get_rotation_to_best_fit_xy(
16+
points: torch.Tensor, centroid: Optional[torch.Tensor] = None
1717
) -> torch.Tensor:
1818
"""
19-
Returns a rotation r such that points @ r has a best fit plane
19+
Returns a rotation R such that `points @ R` has a best fit plane
2020
parallel to the xy plane
2121
2222
Args:
23-
points: (N, 3) tensor of points in 3D
24-
centroid: (3,) their centroid
23+
points: (*, N, 3) tensor of points in 3D
24+
centroid: (*, 1, 3), (3,) or scalar: their centroid
2525
2626
Returns:
27-
(3,3) tensor rotation matrix
27+
(*, 3, 3) tensor rotation matrix
2828
"""
29-
points_centered = points - centroid[None]
30-
return torch.linalg.eigh(points_centered.t() @ points_centered)[1][:, [1, 2, 0]]
29+
if centroid is None:
30+
centroid = points.mean(dim=-2, keepdim=True)
31+
32+
points_centered = points - centroid
33+
_, evec = torch.linalg.eigh(points_centered.transpose(-1, -2) @ points_centered)
34+
# in general, evec can form either right- or left-handed basis,
35+
# but we need the former to have a proper rotation (not reflection)
36+
return torch.cat(
37+
(evec[..., 1:], torch.cross(evec[..., 1], evec[..., 2])[..., None]), dim=-1
38+
)
3139

3240

3341
def _signed_area(path: torch.Tensor) -> torch.Tensor:
@@ -191,7 +199,7 @@ def fit_circle_in_3d(
191199
Circle3D object
192200
"""
193201
centroid = points.mean(0)
194-
r = _get_rotation_to_best_fit_xy(points, centroid)
202+
r = get_rotation_to_best_fit_xy(points, centroid)
195203
normal = r[:, 2]
196204
rotated_points = (points - centroid) @ r
197205
result_2d = fit_circle_in_2d(

tests/implicitron/test_circle_fitting.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
_signed_area,
1313
fit_circle_in_2d,
1414
fit_circle_in_3d,
15+
get_rotation_to_best_fit_xy,
1516
)
16-
from pytorch3d.transforms import random_rotation
17+
from pytorch3d.transforms import random_rotation, random_rotations
1718
from tests.common_testing import TestCaseMixin
1819

1920

@@ -28,6 +29,32 @@ def _assertParallel(self, a, b, **kwargs):
2829
"""
2930
self.assertClose(torch.cross(a, b, dim=-1), torch.zeros_like(a), **kwargs)
3031

32+
def test_plane_levelling(self):
33+
device = torch.device("cuda:0")
34+
B = 16
35+
N = 1024
36+
random = torch.randn((B, N, 3), device=device)
37+
38+
# first, check that we always return a vaild rotation
39+
rot = get_rotation_to_best_fit_xy(random)
40+
self.assertClose(rot.det(), torch.ones_like(rot[:, 0, 0]))
41+
self.assertClose(rot.norm(dim=-1), torch.ones_like(rot[:, 0]))
42+
43+
# then, check the result is what we expect
44+
z_squeeze = 0.1
45+
random[..., -1] *= z_squeeze
46+
rot_gt = random_rotations(B, device=device)
47+
rotated = random @ rot_gt.transpose(-1, -2)
48+
rot_hat = get_rotation_to_best_fit_xy(rotated)
49+
self.assertClose(rot.det(), torch.ones_like(rot[:, 0, 0]))
50+
self.assertClose(rot.norm(dim=-1), torch.ones_like(rot[:, 0]))
51+
# covariance matrix of the levelled points is by design diag(1, 1, z_squeeze²)
52+
self.assertClose(
53+
(rotated @ rot_hat)[..., -1].std(dim=-1),
54+
torch.ones_like(rot_hat[:, 0, 0]) * z_squeeze,
55+
rtol=0.1,
56+
)
57+
3158
def test_simple_3d(self):
3259
device = torch.device("cuda:0")
3360
for _ in range(7):

0 commit comments

Comments
 (0)