Skip to content

Commit 5ffeb4d

Browse files
Norman Muellerfacebook-github-bot
Norman Mueller
authored andcommitted
Single directional chamfer distance and non-absolute cosine similarity
Summary: Single directional chamfer distance and option to use non-absolute cosine similarity Reviewed By: bottler Differential Revision: D46593980 fbshipit-source-id: b2e591706a0cdde1c2d361614cecebb84a581433
1 parent 573a42c commit 5ffeb4d

File tree

2 files changed

+310
-105
lines changed

2 files changed

+310
-105
lines changed

pytorch3d/loss/chamfer.py

Lines changed: 114 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -68,74 +68,28 @@ def _handle_pointcloud_input(
6868
return X, lengths, normals
6969

7070

71-
def chamfer_distance(
71+
def _chamfer_distance_single_direction(
7272
x,
7373
y,
74-
x_lengths=None,
75-
y_lengths=None,
76-
x_normals=None,
77-
y_normals=None,
78-
weights=None,
79-
batch_reduction: Union[str, None] = "mean",
80-
point_reduction: str = "mean",
81-
norm: int = 2,
74+
x_lengths,
75+
y_lengths,
76+
x_normals,
77+
y_normals,
78+
weights,
79+
batch_reduction: Union[str, None],
80+
point_reduction: str,
81+
norm: int,
82+
abs_cosine: bool,
8283
):
83-
"""
84-
Chamfer distance between two pointclouds x and y.
85-
86-
Args:
87-
x: FloatTensor of shape (N, P1, D) or a Pointclouds object representing
88-
a batch of point clouds with at most P1 points in each batch element,
89-
batch size N and feature dimension D.
90-
y: FloatTensor of shape (N, P2, D) or a Pointclouds object representing
91-
a batch of point clouds with at most P2 points in each batch element,
92-
batch size N and feature dimension D.
93-
x_lengths: Optional LongTensor of shape (N,) giving the number of points in each
94-
cloud in x.
95-
y_lengths: Optional LongTensor of shape (N,) giving the number of points in each
96-
cloud in y.
97-
x_normals: Optional FloatTensor of shape (N, P1, D).
98-
y_normals: Optional FloatTensor of shape (N, P2, D).
99-
weights: Optional FloatTensor of shape (N,) giving weights for
100-
batch elements for reduction operation.
101-
batch_reduction: Reduction operation to apply for the loss across the
102-
batch, can be one of ["mean", "sum"] or None.
103-
point_reduction: Reduction operation to apply for the loss across the
104-
points, can be one of ["mean", "sum"].
105-
norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2.
106-
107-
Returns:
108-
2-element tuple containing
109-
110-
- **loss**: Tensor giving the reduced distance between the pointclouds
111-
in x and the pointclouds in y.
112-
- **loss_normals**: Tensor giving the reduced cosine distance of normals
113-
between pointclouds in x and pointclouds in y. Returns None if
114-
x_normals and y_normals are None.
115-
"""
116-
_validate_chamfer_reduction_inputs(batch_reduction, point_reduction)
117-
118-
if not ((norm == 1) or (norm == 2)):
119-
raise ValueError("Support for 1 or 2 norm.")
120-
121-
x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
122-
y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)
123-
12484
return_normals = x_normals is not None and y_normals is not None
12585

12686
N, P1, D = x.shape
127-
P2 = y.shape[1]
12887

12988
# Check if inputs are heterogeneous and create a lengths mask.
13089
is_x_heterogeneous = (x_lengths != P1).any()
131-
is_y_heterogeneous = (y_lengths != P2).any()
13290
x_mask = (
13391
torch.arange(P1, device=x.device)[None] >= x_lengths[:, None]
13492
) # shape [N, P1]
135-
y_mask = (
136-
torch.arange(P2, device=y.device)[None] >= y_lengths[:, None]
137-
) # shape [N, P2]
138-
13993
if y.shape[0] != N or y.shape[2] != D:
14094
raise ValueError("y does not have the correct shape.")
14195
if weights is not None:
@@ -153,75 +107,148 @@ def chamfer_distance(
153107
return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0)
154108

155109
cham_norm_x = x.new_zeros(())
156-
cham_norm_y = x.new_zeros(())
157110

158111
x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, norm=norm, K=1)
159-
y_nn = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, norm=norm, K=1)
160-
161112
cham_x = x_nn.dists[..., 0] # (N, P1)
162-
cham_y = y_nn.dists[..., 0] # (N, P2)
163113

164114
if is_x_heterogeneous:
165115
cham_x[x_mask] = 0.0
166-
if is_y_heterogeneous:
167-
cham_y[y_mask] = 0.0
168116

169117
if weights is not None:
170118
cham_x *= weights.view(N, 1)
171-
cham_y *= weights.view(N, 1)
172119

173120
if return_normals:
174121
# Gather the normals using the indices and keep only value for k=0
175122
x_normals_near = knn_gather(y_normals, x_nn.idx, y_lengths)[..., 0, :]
176-
y_normals_near = knn_gather(x_normals, y_nn.idx, x_lengths)[..., 0, :]
177123

178-
cham_norm_x = 1 - torch.abs(
179-
F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6)
180-
)
181-
cham_norm_y = 1 - torch.abs(
182-
F.cosine_similarity(y_normals, y_normals_near, dim=2, eps=1e-6)
183-
)
124+
cosine_sim = F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6)
125+
# If abs_cosine, ignore orientation and take the absolute value of the cosine sim.
126+
cham_norm_x = 1 - (torch.abs(cosine_sim) if abs_cosine else cosine_sim)
184127

185128
if is_x_heterogeneous:
186129
cham_norm_x[x_mask] = 0.0
187-
if is_y_heterogeneous:
188-
cham_norm_y[y_mask] = 0.0
189130

190131
if weights is not None:
191132
cham_norm_x *= weights.view(N, 1)
192-
cham_norm_y *= weights.view(N, 1)
133+
cham_norm_x = cham_norm_x.sum(1) # (N,)
193134

194135
# Apply point reduction
195136
cham_x = cham_x.sum(1) # (N,)
196-
cham_y = cham_y.sum(1) # (N,)
197-
if return_normals:
198-
cham_norm_x = cham_norm_x.sum(1) # (N,)
199-
cham_norm_y = cham_norm_y.sum(1) # (N,)
200137
if point_reduction == "mean":
201138
x_lengths_clamped = x_lengths.clamp(min=1)
202-
y_lengths_clamped = y_lengths.clamp(min=1)
203139
cham_x /= x_lengths_clamped
204-
cham_y /= y_lengths_clamped
205140
if return_normals:
206141
cham_norm_x /= x_lengths_clamped
207-
cham_norm_y /= y_lengths_clamped
208142

209143
if batch_reduction is not None:
210144
# batch_reduction == "sum"
211145
cham_x = cham_x.sum()
212-
cham_y = cham_y.sum()
213146
if return_normals:
214147
cham_norm_x = cham_norm_x.sum()
215-
cham_norm_y = cham_norm_y.sum()
216148
if batch_reduction == "mean":
217149
div = weights.sum() if weights is not None else max(N, 1)
218150
cham_x /= div
219-
cham_y /= div
220151
if return_normals:
221152
cham_norm_x /= div
222-
cham_norm_y /= div
223-
224-
cham_dist = cham_x + cham_y
225-
cham_normals = cham_norm_x + cham_norm_y if return_normals else None
226153

154+
cham_dist = cham_x
155+
cham_normals = cham_norm_x if return_normals else None
227156
return cham_dist, cham_normals
157+
158+
159+
def chamfer_distance(
160+
x,
161+
y,
162+
x_lengths=None,
163+
y_lengths=None,
164+
x_normals=None,
165+
y_normals=None,
166+
weights=None,
167+
batch_reduction: Union[str, None] = "mean",
168+
point_reduction: str = "mean",
169+
norm: int = 2,
170+
single_directional: bool = False,
171+
abs_cosine: bool = True,
172+
):
173+
"""
174+
Chamfer distance between two pointclouds x and y.
175+
176+
Args:
177+
x: FloatTensor of shape (N, P1, D) or a Pointclouds object representing
178+
a batch of point clouds with at most P1 points in each batch element,
179+
batch size N and feature dimension D.
180+
y: FloatTensor of shape (N, P2, D) or a Pointclouds object representing
181+
a batch of point clouds with at most P2 points in each batch element,
182+
batch size N and feature dimension D.
183+
x_lengths: Optional LongTensor of shape (N,) giving the number of points in each
184+
cloud in x.
185+
y_lengths: Optional LongTensor of shape (N,) giving the number of points in each
186+
cloud in y.
187+
x_normals: Optional FloatTensor of shape (N, P1, D).
188+
y_normals: Optional FloatTensor of shape (N, P2, D).
189+
weights: Optional FloatTensor of shape (N,) giving weights for
190+
batch elements for reduction operation.
191+
batch_reduction: Reduction operation to apply for the loss across the
192+
batch, can be one of ["mean", "sum"] or None.
193+
point_reduction: Reduction operation to apply for the loss across the
194+
points, can be one of ["mean", "sum"].
195+
norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2.
196+
single_directional: If False (default), loss comes from both the distance between
197+
each point in x and its nearest neighbor in y and each point in y and its nearest
198+
neighbor in x. If True, loss is the distance between each point in x and its
199+
nearest neighbor in y.
200+
abs_cosine: If False, loss_normals is from one minus the cosine similarity.
201+
If True (default), loss_normals is from one minus the absolute value of the
202+
cosine similarity, which means that exactly opposite normals are considered
203+
equivalent to exactly matching normals, i.e. sign does not matter.
204+
205+
Returns:
206+
2-element tuple containing
207+
208+
- **loss**: Tensor giving the reduced distance between the pointclouds
209+
in x and the pointclouds in y.
210+
- **loss_normals**: Tensor giving the reduced cosine distance of normals
211+
between pointclouds in x and pointclouds in y. Returns None if
212+
x_normals and y_normals are None.
213+
214+
"""
215+
_validate_chamfer_reduction_inputs(batch_reduction, point_reduction)
216+
217+
if not ((norm == 1) or (norm == 2)):
218+
raise ValueError("Support for 1 or 2 norm.")
219+
x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
220+
y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)
221+
222+
cham_x, cham_norm_x = _chamfer_distance_single_direction(
223+
x,
224+
y,
225+
x_lengths,
226+
y_lengths,
227+
x_normals,
228+
y_normals,
229+
weights,
230+
batch_reduction,
231+
point_reduction,
232+
norm,
233+
abs_cosine,
234+
)
235+
if single_directional:
236+
return cham_x, cham_norm_x
237+
else:
238+
cham_y, cham_norm_y = _chamfer_distance_single_direction(
239+
y,
240+
x,
241+
y_lengths,
242+
x_lengths,
243+
y_normals,
244+
x_normals,
245+
weights,
246+
batch_reduction,
247+
point_reduction,
248+
norm,
249+
abs_cosine,
250+
)
251+
return (
252+
cham_x + cham_y,
253+
(cham_norm_x + cham_norm_y) if cham_norm_x is not None else None,
254+
)

0 commit comments

Comments
 (0)