
Commit a116847

Make TensorNet compatible with TorchScript (#186)
* Change some lines incompatible with jit script
* Remove some empty lines
* Fix typo
* Include an assert to appease TorchScript
* Change a range loop to an enumerate
* Add test for skewtensor function
* Small changes from merge
* Update test
* Update vector_to_skewtensor
* Remove some parenthesis
* Small changes
* Remove skewtensor test
* Annotate types in Atomref
* Simplify a couple of operations
* Check also derivative in torchscript test
* Type annotate forward LLNP
* Try double backward in the TorchScript test
* Change test name
* Annotate forward
* Remove unused variables
* Remove unnecessary enumerates
* Add TorchScript GPU tests
1 parent e26dd40 commit a116847
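Most of the individual changes serve two TorchScript constraints that recur throughout the diff: arguments that may be `None` must be annotated `Optional[...]` (unannotated arguments are assumed to be `Tensor`), and an `nn.ModuleList` must be iterated directly rather than indexed with a loop variable. A minimal sketch of both constraints, using a toy module rather than anything from this repository:

```python
import torch
from torch import Tensor, nn
from typing import Optional

class ToyModel(nn.Module):
    def __init__(self, num_layers: int = 3):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(num_layers))

    # TorchScript assumes unannotated arguments are Tensor, so anything
    # that may be None has to be spelled out as Optional[Tensor].
    def forward(self, x: Tensor, scale: Optional[Tensor] = None) -> Tensor:
        # Indexing self.layers[i] with a loop variable does not compile;
        # iterating the ModuleList directly does.
        for layer in self.layers:
            x = torch.relu(layer(x))
        if scale is not None:
            x = x * scale
        return x

scripted = torch.jit.script(ToyModel())
print(scripted(torch.randn(2, 8)).shape)
```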

File tree

4 files changed: +128 -66 lines changed

tests/test_model.py

Lines changed: 38 additions & 8 deletions
@@ -37,16 +37,47 @@ def test_forward_output_modules(model_name, output_model, dtype):
 
 
 @mark.parametrize("model_name", models.__all__)
-@mark.parametrize("dtype", [torch.float32, torch.float64])
-def test_forward_torchscript(model_name, dtype):
-    if model_name == "tensornet":
-        pytest.skip("TensorNet does not support torchscript.")
+@mark.parametrize("device", ["cpu", "cuda"])
+def test_torchscript(model_name, device):
+    if device == "cuda" and not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
     z, pos, batch = create_example_batch()
+    z = z.to(device)
+    pos = pos.to(device)
+    batch = batch.to(device)
     model = torch.jit.script(
-        create_model(load_example_args(model_name, remove_prior=True, derivative=True, dtype=dtype))
-    )
-    model(z, pos, batch=batch)
+        create_model(load_example_args(model_name, remove_prior=True, derivative=True))
+    ).to(device=device)
+    y, neg_dy = model(z, pos, batch=batch)
+    grad_outputs = [torch.ones_like(neg_dy)]
+    ddy = torch.autograd.grad(
+        [neg_dy],
+        [pos],
+        grad_outputs=grad_outputs,
+    )[0]
 
+@mark.parametrize("model_name", models.__all__)
+@mark.parametrize("device", ["cpu", "cuda"])
+def test_torchscript_dynamic_shapes(model_name, device):
+    if device == "cuda" and not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
+    z, pos, batch = create_example_batch()
+    model = torch.jit.script(
+        create_model(load_example_args(model_name, remove_prior=True, derivative=True))
+    ).to(device=device)
+    # Repeat the input to make it dynamic
+    for rep in range(0, 5):
+        print(rep)
+        zi = z.repeat_interleave(rep + 1, dim=0).to(device=device)
+        posi = pos.repeat_interleave(rep + 1, dim=0).to(device=device)
+        batchi = torch.randint(0, 10, (zi.shape[0],)).sort()[0].to(device=device)
+        y, neg_dy = model(zi, posi, batch=batchi)
+        grad_outputs = [torch.ones_like(neg_dy)]
+        ddy = torch.autograd.grad(
+            [neg_dy],
+            [posi],
+            grad_outputs=grad_outputs,
+        )[0]
 
 @mark.parametrize("model_name", models.__all__)
 def test_seed(model_name):
@@ -59,7 +90,6 @@ def test_seed(model_name):
     for p1, p2 in zip(m1.parameters(), m2.parameters()):
        assert (p1 == p2).all(), "Parameters don't match although using the same seed."
 
-
 @mark.parametrize("model_name", models.__all__)
 @mark.parametrize(
     "output_model",

torchmdnet/models/tensornet.py

Lines changed: 76 additions & 55 deletions
@@ -13,11 +13,23 @@
 
 # Creates a skew-symmetric tensor from a vector
 def vector_to_skewtensor(vector):
-    tensor = torch.cross(
-        *torch.broadcast_tensors(
-            vector[..., None], torch.eye(3, 3, device=vector.device, dtype=vector.dtype)[None, None]
-        )
+    batch_size = vector.size(0)
+    zero = torch.zeros(batch_size, device=vector.device, dtype=vector.dtype)
+    tensor = torch.stack(
+        (
+            zero,
+            -vector[:, 2],
+            vector[:, 1],
+            vector[:, 2],
+            zero,
+            -vector[:, 0],
+            -vector[:, 1],
+            vector[:, 0],
+            zero,
+        ),
+        dim=1,
     )
+    tensor = tensor.view(-1, 3, 3)
     return tensor.squeeze(0)
 
 
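For reference, the nine stacked entries are the standard skew-symmetric (cross-product) matrix of $v = (v_x, v_y, v_z)$, laid out row-major by the `view(-1, 3, 3)`:

$$
[v]_\times =
\begin{pmatrix}
0 & -v_z & v_y \\
v_z & 0 & -v_x \\
-v_y & v_x & 0
\end{pmatrix},
\qquad
[v]_\times\, u = v \times u,
$$

the same matrix the removed `torch.cross` construction produced by crossing $v$ against the identity columns.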

@@ -43,9 +55,9 @@ def decompose_tensor(tensor):
 
 # Modifies tensor by multiplying invariant features to irreducible components
 def new_radial_tensor(I, A, S, f_I, f_A, f_S):
-    I = (f_I)[..., None, None] * I
-    A = (f_A)[..., None, None] * A
-    S = (f_S)[..., None, None] * S
+    I = f_I[..., None, None] * I
+    A = f_A[..., None, None] * A
+    S = f_S[..., None, None] * S
     return I, A, S
 
 
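In equation form, with each radial weight broadcast over its trailing $3 \times 3$ block:

$$
I' = f_I \, I, \qquad A' = f_A \, A, \qquad S' = f_S \, S .
$$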

@@ -102,6 +114,7 @@ def __init__(
         dtype=torch.float32,
     ):
         super(TensorNet, self).__init__()
+
         assert rbf_type in rbf_class_mapping, (
             f'Unknown RBF type "{rbf_type}". '
             f'Choose from {", ".join(rbf_class_mapping.keys())}.'
@@ -110,6 +123,7 @@ def __init__(
             f'Unknown activation function "{activation}". '
             f'Choose from {", ".join(act_class_mapping.keys())}.'
         )
+
         assert equivariance_invariance_group in ["O(3)", "SO(3)"], (
             f'Unknown group "{equivariance_invariance_group}". '
             f"Choose O(3) or SO(3)."
@@ -139,6 +153,7 @@ def __init__(
             max_z,
             dtype,
         ).jittable()
+
         self.layers = nn.ModuleList()
         if num_layers != 0:
             for _ in range(num_layers):
@@ -160,23 +175,34 @@ def __init__(
 
     def reset_parameters(self):
         self.tensor_embedding.reset_parameters()
-        for i in range(self.num_layers):
-            self.layers[i].reset_parameters()
+        for layer in self.layers:
+            layer.reset_parameters()
         self.linear.reset_parameters()
         self.out_norm.reset_parameters()
 
     def forward(
-        self, z, pos, batch, q: Optional[Tensor] = None, s: Optional[Tensor] = None
-    ):
+        self,
+        z: Tensor,
+        pos: Tensor,
+        batch: Tensor,
+        q: Optional[Tensor] = None,
+        s: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
+
         # Obtain graph, with distances and relative position vectors
         edge_index, edge_weight, edge_vec = self.distance(pos, batch)
+        # This assert convinces TorchScript that edge_vec is a Tensor and not an Optional[Tensor]
+        assert (
+            edge_vec is not None
+        ), "Distance module did not return directional information"
+
         # Expand distances with radial basis functions
         edge_attr = self.distance_expansion(edge_weight)
         # Embedding from edge-wise tensors to node-wise tensors
         X = self.tensor_embedding(z, edge_index, edge_weight, edge_vec, edge_attr)
         # Interaction layers
-        for i in range(self.num_layers):
-            X = self.layers[i](X, edge_index, edge_weight, edge_attr)
+        for layer in self.layers:
+            X = layer(X, edge_index, edge_weight, edge_attr)
         I, A, S = decompose_tensor(X)
         x = torch.cat((tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1)
         x = self.out_norm(x)
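The `assert edge_vec is not None` is TorchScript's supported way to refine an `Optional[Tensor]`: the distance module is typed as returning optional directional information, and without the refinement every later use of `edge_vec` would fail to compile. The same pattern in isolation, with a toy scripted function rather than the repository's Distance module:

```python
import torch
from torch import Tensor
from typing import Optional

@torch.jit.script
def normalize_edges(edge_vec: Optional[Tensor]) -> Tensor:
    # Before the assert, edge_vec is Optional[Tensor] and supports no tensor
    # ops; after it, TorchScript narrows the type to Tensor.
    assert edge_vec is not None, "expected directional information"
    return edge_vec / (edge_vec.norm(dim=1, keepdim=True) + 1e-9)

print(normalize_edges(torch.randn(4, 3)).shape)
```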
@@ -208,15 +234,10 @@ def __init__(
         self.emb2 = nn.Linear(2 * hidden_channels, hidden_channels, dtype=dtype)
         self.act = activation()
         self.linears_tensor = nn.ModuleList()
-        self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
-        )
-        self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
-        )
-        self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
-        )
+        for _ in range(3):
+            self.linears_tensor.append(
+                nn.Linear(hidden_channels, hidden_channels, bias=False)
+            )
         self.linears_scalar = nn.ModuleList()
         self.linears_scalar.append(
             nn.Linear(hidden_channels, 2 * hidden_channels, bias=True, dtype=dtype)
@@ -239,16 +260,26 @@ def reset_parameters(self):
             linear.reset_parameters()
         self.init_norm.reset_parameters()
 
-    def forward(self, z, edge_index, edge_weight, edge_vec, edge_attr):
+    def forward(
+        self,
+        z: Tensor,
+        edge_index: Tensor,
+        edge_weight: Tensor,
+        edge_vec: Tensor,
+        edge_attr: Tensor,
+    ):
+
         Z = self.emb(z)
         C = self.cutoff(edge_weight)
-        W1 = (self.distance_proj1(edge_attr)) * C.view(-1, 1)
-        W2 = (self.distance_proj2(edge_attr)) * C.view(-1, 1)
-        W3 = (self.distance_proj3(edge_attr)) * C.view(-1, 1)
+        W1 = self.distance_proj1(edge_attr) * C.view(-1, 1)
+        W2 = self.distance_proj2(edge_attr) * C.view(-1, 1)
+        W3 = self.distance_proj3(edge_attr) * C.view(-1, 1)
         mask = edge_index[0] != edge_index[1]
         edge_vec[mask] = edge_vec[mask] / torch.norm(edge_vec[mask], dim=1).unsqueeze(1)
         Iij, Aij, Sij = new_radial_tensor(
-            torch.eye(3, 3, device=edge_vec.device, dtype=edge_vec.dtype)[None, None, :, :],
+            torch.eye(3, 3, device=edge_vec.device, dtype=edge_vec.dtype)[
+                None, None, :, :
+            ],
             vector_to_skewtensor(edge_vec)[..., None, :, :],
             vector_to_symtensor(edge_vec)[..., None, :, :],
             W1,
@@ -262,11 +293,12 @@ def forward(self, z, edge_index, edge_weight, edge_vec, edge_attr):
         I = self.linears_tensor[0](I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         A = self.linears_tensor[1](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         S = self.linears_tensor[2](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-        for j in range(len(self.linears_scalar)):
-            norm = self.act(self.linears_scalar[j](norm))
+        for linear_scalar in self.linears_scalar:
+            norm = self.act(linear_scalar(norm))
         norm = norm.reshape(norm.shape[0], self.hidden_channels, 3)
         I, A, S = new_radial_tensor(I, A, S, norm[..., 0], norm[..., 1], norm[..., 2])
         X = I + A + S
+
         return X
 
     def message(self, Z_i, Z_j, I, A, S):
@@ -275,6 +307,7 @@ def message(self, Z_i, Z_j, I, A, S):
         I = Zij[..., None, None] * I
         A = Zij[..., None, None] * A
         S = Zij[..., None, None] * S
+
         return I, A, S
 
     def aggregate(
@@ -284,10 +317,12 @@ def aggregate(
         ptr: Optional[torch.Tensor],
         dim_size: Optional[int],
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
         I, A, S = features
         I = scatter(I, index, dim=self.node_dim, dim_size=dim_size)
         A = scatter(A, index, dim=self.node_dim, dim_size=dim_size)
         S = scatter(S, index, dim=self.node_dim, dim_size=dim_size)
+
         return I, A, S
 
     def update(
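The aggregate step sums each edge's irreducible components onto its destination node (`scatter`'s default reduction is a sum); for node $i$ with incident edges $\mathcal{E}(i)$:

$$
I_i = \sum_{e \in \mathcal{E}(i)} I_e, \qquad
A_i = \sum_{e \in \mathcal{E}(i)} A_e, \qquad
S_i = \sum_{e \in \mathcal{E}(i)} S_e .
$$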
@@ -321,24 +356,10 @@ def __init__(
             nn.Linear(2 * hidden_channels, 3 * hidden_channels, bias=True, dtype=dtype)
         )
         self.linears_tensor = nn.ModuleList()
-        self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
-        )
-        self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
-        )
-        self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
-        )
-        self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
-        )
-        self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
-        )
-        self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
-        )
+        for _ in range(6):
+            self.linears_tensor.append(
+                nn.Linear(hidden_channels, hidden_channels, bias=False)
+            )
         self.act = activation()
         self.equivariance_invariance_group = equivariance_invariance_group
         self.reset_parameters()
@@ -350,9 +371,10 @@ def reset_parameters(self):
             linear.reset_parameters()
 
     def forward(self, X, edge_index, edge_weight, edge_attr):
+
         C = self.cutoff(edge_weight)
-        for i in range(len(self.linears_scalar)):
-            edge_attr = self.act(self.linears_scalar[i](edge_attr))
+        for linear_scalar in self.linears_scalar:
+            edge_attr = self.act(linear_scalar(edge_attr))
         edge_attr = (edge_attr * C.view(-1, 1)).reshape(
             edge_attr.shape[0], self.hidden_channels, 3
         )
@@ -374,19 +396,17 @@ def forward(self, X, edge_index, edge_weight, edge_attr):
         if self.equivariance_invariance_group == "SO(3)":
             B = torch.matmul(Y, msg)
             I, A, S = decompose_tensor(2 * B)
-        norm = tensor_norm(I + A + S)
-        I = I / (norm + 1)[..., None, None]
-        A = A / (norm + 1)[..., None, None]
-        S = S / (norm + 1)[..., None, None]
+        normp1 = (tensor_norm(I + A + S) + 1)[..., None, None]
+        I, A, S = I / normp1, A / normp1, S / normp1
         I = self.linears_tensor[3](I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         A = self.linears_tensor[4](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         dX = I + A + S
-        dX = dX + torch.matmul(dX, dX)
-        X = X + dX
+        X = X + dX + dX**2
         return X
 
     def message(self, I_j, A_j, S_j, edge_attr):
+
         I, A, S = new_radial_tensor(
             I_j, A_j, S_j, edge_attr[..., 0], edge_attr[..., 1], edge_attr[..., 2]
         )
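Spelled out, the normalization now builds a single broadcast factor

$$
n = \lVert I + A + S \rVert + 1, \qquad I \leftarrow I / n, \quad A \leftarrow A / n, \quad S \leftarrow S / n,
$$

and the residual update collapses to $X \leftarrow X + \Delta X + \Delta X^2$. One caveat worth noting: `dX**2` squares elementwise in PyTorch, while the removed line computed `torch.matmul(dX, dX)`, a matrix square.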
@@ -399,6 +419,7 @@ def aggregate(
         ptr: Optional[torch.Tensor],
         dim_size: Optional[int],
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
         I, A, S = features
         I = scatter(I, index, dim=self.node_dim, dim_size=dim_size)
         A = scatter(A, index, dim=self.node_dim, dim_size=dim_size)

torchmdnet/module.py

Lines changed: 11 additions & 1 deletion
@@ -2,6 +2,8 @@
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.nn.functional import mse_loss, l1_loss
+from torch import Tensor
+from typing import Optional, Dict, Tuple
 
 from pytorch_lightning import LightningModule
 from torchmdnet.models.model import create_model, load_model
@@ -55,7 +57,15 @@ def configure_optimizers(self):
         }
         return [optimizer], [lr_scheduler]
 
-    def forward(self, z, pos, batch=None, q=None, s=None, extra_args=None):
+    def forward(self,
+                z: Tensor,
+                pos: Tensor,
+                batch: Optional[Tensor] = None,
+                q: Optional[Tensor] = None,
+                s: Optional[Tensor] = None,
+                extra_args: Optional[Dict[str, Tensor]] = None
+                ) -> Tuple[Tensor, Optional[Tensor]]:
+
         return self.model(z, pos, batch=batch, q=q, s=s, extra_args=extra_args)
 
     def training_step(self, batch, batch_idx):
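TorchScript needs container types spelled out in full, which is why `extra_args` is annotated `Optional[Dict[str, Tensor]]` rather than left bare. A minimal illustration with a toy scripted function (not this module's API):

```python
import torch
from torch import Tensor
from typing import Dict, Optional

@torch.jit.script
def count_extra_args(extra_args: Optional[Dict[str, Tensor]]) -> int:
    # Left unannotated, extra_args would be assumed to be a Tensor and
    # both the None default and the dict operation below would be rejected.
    if extra_args is None:
        return 0
    return len(extra_args)

print(count_extra_args(None), count_extra_args({"q": torch.zeros(3)}))
```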

torchmdnet/priors/atomref.py

Lines changed: 3 additions & 2 deletions
@@ -1,6 +1,7 @@
 from torchmdnet.priors.base import BasePrior
+from typing import Optional, Dict
 import torch
-from torch import nn
+from torch import nn, Tensor
 from pytorch_lightning.utilities import rank_zero_warn
 
 
@@ -37,5 +38,5 @@ def reset_parameters(self):
     def get_init_args(self):
         return dict(max_z=self.initial_atomref.size(0))
 
-    def pre_reduce(self, x, z, pos, batch, extra_args):
+    def pre_reduce(self, x: Tensor, z: Tensor, pos: Tensor, batch: Tensor, extra_args: Optional[Dict[str, Tensor]]):
         return x + self.atomref(z)
