13
13
14
14
# Creates a skew-symmetric tensor from a vector
def vector_to_skewtensor(vector):
    """Build the 3x3 skew-symmetric (cross-product) matrix for each row vector.

    Args:
        vector: Tensor of shape (batch, 3); per-row components (x, y, z).

    Returns:
        Tensor of shape (batch, 3, 3) — or (3, 3) when batch == 1, due to the
        trailing squeeze(0) — such that out[i] @ u == cross(vector[i], u).
    """
    batch_size = vector.size(0)
    zero = torch.zeros(batch_size, device=vector.device, dtype=vector.dtype)
    # Assemble the nine entries row-major:
    #   [[ 0, -z,  y],
    #    [ z,  0, -x],
    #    [-y,  x,  0]]
    tensor = torch.stack(
        (
            zero,
            -vector[:, 2],
            vector[:, 1],
            vector[:, 2],
            zero,
            -vector[:, 0],
            -vector[:, 1],
            vector[:, 0],
            zero,
        ),
        dim=1,
    )
    tensor = tensor.view(-1, 3, 3)
    return tensor.squeeze(0)
22
34
23
35
# Modifies tensor by multiplying invariant features to irreducible components
def new_radial_tensor(I, A, S, f_I, f_A, f_S):
    """Scale each irreducible tensor component by its invariant feature.

    Args:
        I, A, S: Irreducible components, each with two trailing 3x3 matrix
            dimensions.
        f_I, f_A, f_S: Scalar feature tensors; two singleton dimensions are
            appended so they broadcast over the matrix dimensions.

    Returns:
        Tuple (I, A, S) of the scaled components.
    """
    I = f_I[..., None, None] * I
    A = f_A[..., None, None] * A
    S = f_S[..., None, None] * S
    return I, A, S
50
62
51
63
@@ -102,6 +114,7 @@ def __init__(
102
114
dtype = torch .float32 ,
103
115
):
104
116
super (TensorNet , self ).__init__ ()
117
+
105
118
assert rbf_type in rbf_class_mapping , (
106
119
f'Unknown RBF type "{ rbf_type } ". '
107
120
f'Choose from { ", " .join (rbf_class_mapping .keys ())} .'
@@ -110,6 +123,7 @@ def __init__(
110
123
f'Unknown activation function "{ activation } ". '
111
124
f'Choose from { ", " .join (act_class_mapping .keys ())} .'
112
125
)
126
+
113
127
assert equivariance_invariance_group in ["O(3)" , "SO(3)" ], (
114
128
f'Unknown group "{ equivariance_invariance_group } ". '
115
129
f"Choose O(3) or SO(3)."
@@ -139,6 +153,7 @@ def __init__(
139
153
max_z ,
140
154
dtype ,
141
155
).jittable ()
156
+
142
157
self .layers = nn .ModuleList ()
143
158
if num_layers != 0 :
144
159
for _ in range (num_layers ):
@@ -160,23 +175,34 @@ def __init__(
160
175
161
176
def reset_parameters(self):
    """Re-initialize the parameters of every submodule owned by this model."""
    self.tensor_embedding.reset_parameters()
    # Delegate to each interaction layer; iterating the ModuleList directly
    # avoids indexing via a separately tracked layer count.
    for layer in self.layers:
        layer.reset_parameters()
    self.linear.reset_parameters()
    self.out_norm.reset_parameters()
167
182
168
183
def forward (
169
- self , z , pos , batch , q : Optional [Tensor ] = None , s : Optional [Tensor ] = None
170
- ):
184
+ self ,
185
+ z : Tensor ,
186
+ pos : Tensor ,
187
+ batch : Tensor ,
188
+ q : Optional [Tensor ] = None ,
189
+ s : Optional [Tensor ] = None ,
190
+ ) -> Tuple [Tensor , Optional [Tensor ], Tensor , Tensor , Tensor ]:
191
+
171
192
# Obtain graph, with distances and relative position vectors
172
193
edge_index , edge_weight , edge_vec = self .distance (pos , batch )
194
+ # This assert convinces TorchScript that edge_vec is a Tensor and not an Optional[Tensor]
195
+ assert (
196
+ edge_vec is not None
197
+ ), "Distance module did not return directional information"
198
+
173
199
# Expand distances with radial basis functions
174
200
edge_attr = self .distance_expansion (edge_weight )
175
201
# Embedding from edge-wise tensors to node-wise tensors
176
202
X = self .tensor_embedding (z , edge_index , edge_weight , edge_vec , edge_attr )
177
203
# Interaction layers
178
- for i in range ( self .num_layers ) :
179
- X = self . layers [ i ] (X , edge_index , edge_weight , edge_attr )
204
+ for layer in self .layers :
205
+ X = layer (X , edge_index , edge_weight , edge_attr )
180
206
I , A , S = decompose_tensor (X )
181
207
x = torch .cat ((tensor_norm (I ), tensor_norm (A ), tensor_norm (S )), dim = - 1 )
182
208
x = self .out_norm (x )
@@ -208,15 +234,10 @@ def __init__(
208
234
self .emb2 = nn .Linear (2 * hidden_channels , hidden_channels , dtype = dtype )
209
235
self .act = activation ()
210
236
self .linears_tensor = nn .ModuleList ()
211
- self .linears_tensor .append (
212
- nn .Linear (hidden_channels , hidden_channels , bias = False , dtype = dtype )
213
- )
214
- self .linears_tensor .append (
215
- nn .Linear (hidden_channels , hidden_channels , bias = False , dtype = dtype )
216
- )
217
- self .linears_tensor .append (
218
- nn .Linear (hidden_channels , hidden_channels , bias = False , dtype = dtype )
219
- )
237
+ for _ in range (3 ):
238
+ self .linears_tensor .append (
239
+ nn .Linear (hidden_channels , hidden_channels , bias = False )
240
+ )
220
241
self .linears_scalar = nn .ModuleList ()
221
242
self .linears_scalar .append (
222
243
nn .Linear (hidden_channels , 2 * hidden_channels , bias = True , dtype = dtype )
@@ -239,16 +260,26 @@ def reset_parameters(self):
239
260
linear .reset_parameters ()
240
261
self .init_norm .reset_parameters ()
241
262
242
- def forward (self , z , edge_index , edge_weight , edge_vec , edge_attr ):
263
+ def forward (
264
+ self ,
265
+ z : Tensor ,
266
+ edge_index : Tensor ,
267
+ edge_weight : Tensor ,
268
+ edge_vec : Tensor ,
269
+ edge_attr : Tensor ,
270
+ ):
271
+
243
272
Z = self .emb (z )
244
273
C = self .cutoff (edge_weight )
245
- W1 = ( self .distance_proj1 (edge_attr ) ) * C .view (- 1 , 1 )
246
- W2 = ( self .distance_proj2 (edge_attr ) ) * C .view (- 1 , 1 )
247
- W3 = ( self .distance_proj3 (edge_attr ) ) * C .view (- 1 , 1 )
274
+ W1 = self .distance_proj1 (edge_attr ) * C .view (- 1 , 1 )
275
+ W2 = self .distance_proj2 (edge_attr ) * C .view (- 1 , 1 )
276
+ W3 = self .distance_proj3 (edge_attr ) * C .view (- 1 , 1 )
248
277
mask = edge_index [0 ] != edge_index [1 ]
249
278
edge_vec [mask ] = edge_vec [mask ] / torch .norm (edge_vec [mask ], dim = 1 ).unsqueeze (1 )
250
279
Iij , Aij , Sij = new_radial_tensor (
251
- torch .eye (3 , 3 , device = edge_vec .device , dtype = edge_vec .dtype )[None , None , :, :],
280
+ torch .eye (3 , 3 , device = edge_vec .device , dtype = edge_vec .dtype )[
281
+ None , None , :, :
282
+ ],
252
283
vector_to_skewtensor (edge_vec )[..., None , :, :],
253
284
vector_to_symtensor (edge_vec )[..., None , :, :],
254
285
W1 ,
@@ -262,11 +293,12 @@ def forward(self, z, edge_index, edge_weight, edge_vec, edge_attr):
262
293
I = self .linears_tensor [0 ](I .permute (0 , 2 , 3 , 1 )).permute (0 , 3 , 1 , 2 )
263
294
A = self .linears_tensor [1 ](A .permute (0 , 2 , 3 , 1 )).permute (0 , 3 , 1 , 2 )
264
295
S = self .linears_tensor [2 ](S .permute (0 , 2 , 3 , 1 )).permute (0 , 3 , 1 , 2 )
265
- for j in range ( len ( self .linears_scalar )) :
266
- norm = self .act (self . linears_scalar [ j ] (norm ))
296
+ for linear_scalar in self .linears_scalar :
297
+ norm = self .act (linear_scalar (norm ))
267
298
norm = norm .reshape (norm .shape [0 ], self .hidden_channels , 3 )
268
299
I , A , S = new_radial_tensor (I , A , S , norm [..., 0 ], norm [..., 1 ], norm [..., 2 ])
269
300
X = I + A + S
301
+
270
302
return X
271
303
272
304
def message (self , Z_i , Z_j , I , A , S ):
@@ -275,6 +307,7 @@ def message(self, Z_i, Z_j, I, A, S):
275
307
I = Zij [..., None , None ] * I
276
308
A = Zij [..., None , None ] * A
277
309
S = Zij [..., None , None ] * S
310
+
278
311
return I , A , S
279
312
280
313
def aggregate (
@@ -284,10 +317,12 @@ def aggregate(
284
317
ptr : Optional [torch .Tensor ],
285
318
dim_size : Optional [int ],
286
319
) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
320
+
287
321
I , A , S = features
288
322
I = scatter (I , index , dim = self .node_dim , dim_size = dim_size )
289
323
A = scatter (A , index , dim = self .node_dim , dim_size = dim_size )
290
324
S = scatter (S , index , dim = self .node_dim , dim_size = dim_size )
325
+
291
326
return I , A , S
292
327
293
328
def update (
@@ -321,24 +356,10 @@ def __init__(
321
356
nn .Linear (2 * hidden_channels , 3 * hidden_channels , bias = True , dtype = dtype )
322
357
)
323
358
self .linears_tensor = nn .ModuleList ()
324
- self .linears_tensor .append (
325
- nn .Linear (hidden_channels , hidden_channels , bias = False , dtype = dtype )
326
- )
327
- self .linears_tensor .append (
328
- nn .Linear (hidden_channels , hidden_channels , bias = False , dtype = dtype )
329
- )
330
- self .linears_tensor .append (
331
- nn .Linear (hidden_channels , hidden_channels , bias = False , dtype = dtype )
332
- )
333
- self .linears_tensor .append (
334
- nn .Linear (hidden_channels , hidden_channels , bias = False , dtype = dtype )
335
- )
336
- self .linears_tensor .append (
337
- nn .Linear (hidden_channels , hidden_channels , bias = False , dtype = dtype )
338
- )
339
- self .linears_tensor .append (
340
- nn .Linear (hidden_channels , hidden_channels , bias = False , dtype = dtype )
341
- )
359
+ for _ in range (6 ):
360
+ self .linears_tensor .append (
361
+ nn .Linear (hidden_channels , hidden_channels , bias = False )
362
+ )
342
363
self .act = activation ()
343
364
self .equivariance_invariance_group = equivariance_invariance_group
344
365
self .reset_parameters ()
@@ -350,9 +371,10 @@ def reset_parameters(self):
350
371
linear .reset_parameters ()
351
372
352
373
def forward (self , X , edge_index , edge_weight , edge_attr ):
374
+
353
375
C = self .cutoff (edge_weight )
354
- for i in range ( len ( self .linears_scalar )) :
355
- edge_attr = self .act (self . linears_scalar [ i ] (edge_attr ))
376
+ for linear_scalar in self .linears_scalar :
377
+ edge_attr = self .act (linear_scalar (edge_attr ))
356
378
edge_attr = (edge_attr * C .view (- 1 , 1 )).reshape (
357
379
edge_attr .shape [0 ], self .hidden_channels , 3
358
380
)
@@ -374,19 +396,17 @@ def forward(self, X, edge_index, edge_weight, edge_attr):
374
396
if self .equivariance_invariance_group == "SO(3)" :
375
397
B = torch .matmul (Y , msg )
376
398
I , A , S = decompose_tensor (2 * B )
377
- norm = tensor_norm (I + A + S )
378
- I = I / (norm + 1 )[..., None , None ]
379
- A = A / (norm + 1 )[..., None , None ]
380
- S = S / (norm + 1 )[..., None , None ]
399
+ normp1 = (tensor_norm (I + A + S ) + 1 )[..., None , None ]
400
+ I , A , S = I / normp1 , A / normp1 , S / normp1
381
401
I = self .linears_tensor [3 ](I .permute (0 , 2 , 3 , 1 )).permute (0 , 3 , 1 , 2 )
382
402
A = self .linears_tensor [4 ](A .permute (0 , 2 , 3 , 1 )).permute (0 , 3 , 1 , 2 )
383
403
S = self .linears_tensor [5 ](S .permute (0 , 2 , 3 , 1 )).permute (0 , 3 , 1 , 2 )
384
404
dX = I + A + S
385
- dX = dX + torch .matmul (dX , dX )
386
- X = X + dX
405
+ X = X + dX + dX ** 2
387
406
return X
388
407
389
408
def message (self , I_j , A_j , S_j , edge_attr ):
409
+
390
410
I , A , S = new_radial_tensor (
391
411
I_j , A_j , S_j , edge_attr [..., 0 ], edge_attr [..., 1 ], edge_attr [..., 2 ]
392
412
)
@@ -399,6 +419,7 @@ def aggregate(
399
419
ptr : Optional [torch .Tensor ],
400
420
dim_size : Optional [int ],
401
421
) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
422
+
402
423
I , A , S = features
403
424
I = scatter (I , index , dim = self .node_dim , dim_size = dim_size )
404
425
A = scatter (A , index , dim = self .node_dim , dim_size = dim_size )
0 commit comments