@@ -68,74 +68,28 @@ def _handle_pointcloud_input(
68
68
return X , lengths , normals
69
69
70
70
71
- def chamfer_distance (
71
+ def _chamfer_distance_single_direction (
72
72
x ,
73
73
y ,
74
- x_lengths = None ,
75
- y_lengths = None ,
76
- x_normals = None ,
77
- y_normals = None ,
78
- weights = None ,
79
- batch_reduction : Union [str , None ] = "mean" ,
80
- point_reduction : str = "mean" ,
81
- norm : int = 2 ,
74
+ x_lengths ,
75
+ y_lengths ,
76
+ x_normals ,
77
+ y_normals ,
78
+ weights ,
79
+ batch_reduction : Union [str , None ],
80
+ point_reduction : str ,
81
+ norm : int ,
82
+ abs_cosine : bool ,
82
83
):
83
- """
84
- Chamfer distance between two pointclouds x and y.
85
-
86
- Args:
87
- x: FloatTensor of shape (N, P1, D) or a Pointclouds object representing
88
- a batch of point clouds with at most P1 points in each batch element,
89
- batch size N and feature dimension D.
90
- y: FloatTensor of shape (N, P2, D) or a Pointclouds object representing
91
- a batch of point clouds with at most P2 points in each batch element,
92
- batch size N and feature dimension D.
93
- x_lengths: Optional LongTensor of shape (N,) giving the number of points in each
94
- cloud in x.
95
- y_lengths: Optional LongTensor of shape (N,) giving the number of points in each
96
- cloud in y.
97
- x_normals: Optional FloatTensor of shape (N, P1, D).
98
- y_normals: Optional FloatTensor of shape (N, P2, D).
99
- weights: Optional FloatTensor of shape (N,) giving weights for
100
- batch elements for reduction operation.
101
- batch_reduction: Reduction operation to apply for the loss across the
102
- batch, can be one of ["mean", "sum"] or None.
103
- point_reduction: Reduction operation to apply for the loss across the
104
- points, can be one of ["mean", "sum"].
105
- norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2.
106
-
107
- Returns:
108
- 2-element tuple containing
109
-
110
- - **loss**: Tensor giving the reduced distance between the pointclouds
111
- in x and the pointclouds in y.
112
- - **loss_normals**: Tensor giving the reduced cosine distance of normals
113
- between pointclouds in x and pointclouds in y. Returns None if
114
- x_normals and y_normals are None.
115
- """
116
- _validate_chamfer_reduction_inputs (batch_reduction , point_reduction )
117
-
118
- if not ((norm == 1 ) or (norm == 2 )):
119
- raise ValueError ("Support for 1 or 2 norm." )
120
-
121
- x , x_lengths , x_normals = _handle_pointcloud_input (x , x_lengths , x_normals )
122
- y , y_lengths , y_normals = _handle_pointcloud_input (y , y_lengths , y_normals )
123
-
124
84
return_normals = x_normals is not None and y_normals is not None
125
85
126
86
N , P1 , D = x .shape
127
- P2 = y .shape [1 ]
128
87
129
88
# Check if inputs are heterogeneous and create a lengths mask.
130
89
is_x_heterogeneous = (x_lengths != P1 ).any ()
131
- is_y_heterogeneous = (y_lengths != P2 ).any ()
132
90
x_mask = (
133
91
torch .arange (P1 , device = x .device )[None ] >= x_lengths [:, None ]
134
92
) # shape [N, P1]
135
- y_mask = (
136
- torch .arange (P2 , device = y .device )[None ] >= y_lengths [:, None ]
137
- ) # shape [N, P2]
138
-
139
93
if y .shape [0 ] != N or y .shape [2 ] != D :
140
94
raise ValueError ("y does not have the correct shape." )
141
95
if weights is not None :
@@ -153,75 +107,148 @@ def chamfer_distance(
153
107
return ((x .sum ((1 , 2 )) * weights ) * 0.0 , (x .sum ((1 , 2 )) * weights ) * 0.0 )
154
108
155
109
cham_norm_x = x .new_zeros (())
156
- cham_norm_y = x .new_zeros (())
157
110
158
111
x_nn = knn_points (x , y , lengths1 = x_lengths , lengths2 = y_lengths , norm = norm , K = 1 )
159
- y_nn = knn_points (y , x , lengths1 = y_lengths , lengths2 = x_lengths , norm = norm , K = 1 )
160
-
161
112
cham_x = x_nn .dists [..., 0 ] # (N, P1)
162
- cham_y = y_nn .dists [..., 0 ] # (N, P2)
163
113
164
114
if is_x_heterogeneous :
165
115
cham_x [x_mask ] = 0.0
166
- if is_y_heterogeneous :
167
- cham_y [y_mask ] = 0.0
168
116
169
117
if weights is not None :
170
118
cham_x *= weights .view (N , 1 )
171
- cham_y *= weights .view (N , 1 )
172
119
173
120
if return_normals :
174
121
# Gather the normals using the indices and keep only value for k=0
175
122
x_normals_near = knn_gather (y_normals , x_nn .idx , y_lengths )[..., 0 , :]
176
- y_normals_near = knn_gather (x_normals , y_nn .idx , x_lengths )[..., 0 , :]
177
123
178
- cham_norm_x = 1 - torch .abs (
179
- F .cosine_similarity (x_normals , x_normals_near , dim = 2 , eps = 1e-6 )
180
- )
181
- cham_norm_y = 1 - torch .abs (
182
- F .cosine_similarity (y_normals , y_normals_near , dim = 2 , eps = 1e-6 )
183
- )
124
+ cosine_sim = F .cosine_similarity (x_normals , x_normals_near , dim = 2 , eps = 1e-6 )
125
+ # If abs_cosine, ignore orientation and take the absolute value of the cosine sim.
126
+ cham_norm_x = 1 - (torch .abs (cosine_sim ) if abs_cosine else cosine_sim )
184
127
185
128
if is_x_heterogeneous :
186
129
cham_norm_x [x_mask ] = 0.0
187
- if is_y_heterogeneous :
188
- cham_norm_y [y_mask ] = 0.0
189
130
190
131
if weights is not None :
191
132
cham_norm_x *= weights .view (N , 1 )
192
- cham_norm_y *= weights . view (N , 1 )
133
+ cham_norm_x = cham_norm_x . sum ( 1 ) # (N,)
193
134
194
135
# Apply point reduction
195
136
cham_x = cham_x .sum (1 ) # (N,)
196
- cham_y = cham_y .sum (1 ) # (N,)
197
- if return_normals :
198
- cham_norm_x = cham_norm_x .sum (1 ) # (N,)
199
- cham_norm_y = cham_norm_y .sum (1 ) # (N,)
200
137
if point_reduction == "mean" :
201
138
x_lengths_clamped = x_lengths .clamp (min = 1 )
202
- y_lengths_clamped = y_lengths .clamp (min = 1 )
203
139
cham_x /= x_lengths_clamped
204
- cham_y /= y_lengths_clamped
205
140
if return_normals :
206
141
cham_norm_x /= x_lengths_clamped
207
- cham_norm_y /= y_lengths_clamped
208
142
209
143
if batch_reduction is not None :
210
144
# batch_reduction == "sum"
211
145
cham_x = cham_x .sum ()
212
- cham_y = cham_y .sum ()
213
146
if return_normals :
214
147
cham_norm_x = cham_norm_x .sum ()
215
- cham_norm_y = cham_norm_y .sum ()
216
148
if batch_reduction == "mean" :
217
149
div = weights .sum () if weights is not None else max (N , 1 )
218
150
cham_x /= div
219
- cham_y /= div
220
151
if return_normals :
221
152
cham_norm_x /= div
222
- cham_norm_y /= div
223
-
224
- cham_dist = cham_x + cham_y
225
- cham_normals = cham_norm_x + cham_norm_y if return_normals else None
226
153
154
+ cham_dist = cham_x
155
+ cham_normals = cham_norm_x if return_normals else None
227
156
return cham_dist , cham_normals
157
+
158
+
159
+ def chamfer_distance (
160
+ x ,
161
+ y ,
162
+ x_lengths = None ,
163
+ y_lengths = None ,
164
+ x_normals = None ,
165
+ y_normals = None ,
166
+ weights = None ,
167
+ batch_reduction : Union [str , None ] = "mean" ,
168
+ point_reduction : str = "mean" ,
169
+ norm : int = 2 ,
170
+ single_directional : bool = False ,
171
+ abs_cosine : bool = True ,
172
+ ):
173
+ """
174
+ Chamfer distance between two pointclouds x and y.
175
+
176
+ Args:
177
+ x: FloatTensor of shape (N, P1, D) or a Pointclouds object representing
178
+ a batch of point clouds with at most P1 points in each batch element,
179
+ batch size N and feature dimension D.
180
+ y: FloatTensor of shape (N, P2, D) or a Pointclouds object representing
181
+ a batch of point clouds with at most P2 points in each batch element,
182
+ batch size N and feature dimension D.
183
+ x_lengths: Optional LongTensor of shape (N,) giving the number of points in each
184
+ cloud in x.
185
+ y_lengths: Optional LongTensor of shape (N,) giving the number of points in each
186
+ cloud in y.
187
+ x_normals: Optional FloatTensor of shape (N, P1, D).
188
+ y_normals: Optional FloatTensor of shape (N, P2, D).
189
+ weights: Optional FloatTensor of shape (N,) giving weights for
190
+ batch elements for reduction operation.
191
+ batch_reduction: Reduction operation to apply for the loss across the
192
+ batch, can be one of ["mean", "sum"] or None.
193
+ point_reduction: Reduction operation to apply for the loss across the
194
+ points, can be one of ["mean", "sum"].
195
+ norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2.
196
+ single_directional: If False (default), loss comes from both the distance between
197
+ each point in x and its nearest neighbor in y and each point in y and its nearest
198
+ neighbor in x. If True, loss is the distance between each point in x and its
199
+ nearest neighbor in y.
200
+ abs_cosine: If False, loss_normals is from one minus the cosine similarity.
201
+ If True (default), loss_normals is from one minus the absolute value of the
202
+ cosine similarity, which means that exactly opposite normals are considered
203
+ equivalent to exactly matching normals, i.e. sign does not matter.
204
+
205
+ Returns:
206
+ 2-element tuple containing
207
+
208
+ - **loss**: Tensor giving the reduced distance between the pointclouds
209
+ in x and the pointclouds in y.
210
+ - **loss_normals**: Tensor giving the reduced cosine distance of normals
211
+ between pointclouds in x and pointclouds in y. Returns None if
212
+ x_normals and y_normals are None.
213
+
214
+ """
215
+ _validate_chamfer_reduction_inputs (batch_reduction , point_reduction )
216
+
217
+ if not ((norm == 1 ) or (norm == 2 )):
218
+ raise ValueError ("Support for 1 or 2 norm." )
219
+ x , x_lengths , x_normals = _handle_pointcloud_input (x , x_lengths , x_normals )
220
+ y , y_lengths , y_normals = _handle_pointcloud_input (y , y_lengths , y_normals )
221
+
222
+ cham_x , cham_norm_x = _chamfer_distance_single_direction (
223
+ x ,
224
+ y ,
225
+ x_lengths ,
226
+ y_lengths ,
227
+ x_normals ,
228
+ y_normals ,
229
+ weights ,
230
+ batch_reduction ,
231
+ point_reduction ,
232
+ norm ,
233
+ abs_cosine ,
234
+ )
235
+ if single_directional :
236
+ return cham_x , cham_norm_x
237
+ else :
238
+ cham_y , cham_norm_y = _chamfer_distance_single_direction (
239
+ y ,
240
+ x ,
241
+ y_lengths ,
242
+ x_lengths ,
243
+ y_normals ,
244
+ x_normals ,
245
+ weights ,
246
+ batch_reduction ,
247
+ point_reduction ,
248
+ norm ,
249
+ abs_cosine ,
250
+ )
251
+ return (
252
+ cham_x + cham_y ,
253
+ (cham_norm_x + cham_norm_y ) if cham_norm_x is not None else None ,
254
+ )
0 commit comments