Skip to content

Commit f12c4e8

Browse files
rusty1sJakub Pietrak
authored andcommitted
Check ONNX output equality (pyg-team#5997)
1 parent 486597c commit f12c4e8

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1919
- Added `SparseTensor` support to `SuperGATConv` ([#5888](https://github.com/pyg-team/pytorch_geometric/pull/5888))
2020
- Added TorchScript support for `AttentiveFP `([#5868](https://github.com/pyg-team/pytorch_geometric/pull/5868))
2121
- Added `num_steps` argument to training and inference benchmarks ([#5898](https://github.com/pyg-team/pytorch_geometric/pull/5898))
22-
- Added `torch.onnx.export` support ([#5877](https://github.com/pyg-team/pytorch_geometric/pull/5877))
22+
- Added `torch.onnx.export` support ([#5877](https://github.com/pyg-team/pytorch_geometric/pull/5877), [#5997](https://github.com/pyg-team/pytorch_geometric/pull/5997))
2323
- Enable VTune ITT in inference and training benchmarks ([#5830](https://github.com/pyg-team/pytorch_geometric/pull/5830), [#5878](https://github.com/pyg-team/pytorch_geometric/pull/5878))
2424
- Add training benchmark ([#5774](https://github.com/pyg-team/pytorch_geometric/pull/5774))
2525
- Added a "Link Prediction on MovieLens" Colab notebook ([#5823](https://github.com/pyg-team/pytorch_geometric/pull/5823))

test/nn/models/test_basic_gnn.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,10 +184,12 @@ def forward(self, x, edge_index):
184184

185185
model = MyModel()
186186
x = torch.randn(3, 8)
187-
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
187+
edge_index = torch.tensor([[0, 1, 2], [1, 0, 2]])
188+
expected = model(x, edge_index)
189+
assert expected.size() == (3, 16)
188190

189191
torch.onnx.export(model, (x, edge_index), 'model.onnx',
190-
input_names=('x', 'edge_index'))
192+
input_names=('x', 'edge_index'), opset_version=16)
191193

192194
model = onnx.load('model.onnx')
193195
onnx.checker.check_model(model)
@@ -198,6 +200,7 @@ def forward(self, x, edge_index):
198200
'x': x.numpy(),
199201
'edge_index': edge_index.numpy()
200202
})[0]
201-
assert out.shape == (3, 16)
203+
out = torch.from_numpy(out)
204+
assert torch.allclose(out, expected, atol=1e-6)
202205

203206
os.remove('model.onnx')

0 commit comments

Comments
 (0)