Skip to content

Commit 5938e97

Browse files
authored
Fixing TorchScript support in BatchNorm (#5614)
* update * update * changelog
1 parent 7029bad commit 5938e97

File tree

3 files changed

+12
-6
lines changed

3 files changed

+12
-6
lines changed

.github/workflows/full_testing.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ jobs:
4040
4141
- name: Install internal dependencies
4242
run: |
43+
pip install pyg-lib -f https://data.pyg.org/whl/nightly/torch-${{ matrix.torch-version }}+cpu.html
4344
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-${{ matrix.torch-version }}+cpu.html
4445
4546
- name: Install main package

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4040
- Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
4141
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603))
4242
- Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
43-
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530))
43+
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614))
4444
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))
4545
- Fixed a bug when applying several scalers with `PNAConv` ([#5514](https://github.com/pyg-team/pytorch_geometric/issues/5514))
4646
- Allow `.` in `ParameterDict` key names ([#5494](https://github.com/pyg-team/pytorch_geometric/pull/5494))

torch_geometric/nn/norm/batch_norm.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,16 @@ def reset_parameters(self):
5656
def forward(self, x: Tensor) -> Tensor:
5757
""""""
5858
if self.allow_single_element and x.size(0) <= 1:
59-
training = self.module.training
60-
self.module.eval()
61-
out = self.module(x)
62-
self.module.training = training
63-
return out
59+
return torch.nn.functional.batch_norm(
60+
x,
61+
self.module.running_mean,
62+
self.module.running_var,
63+
self.module.weight,
64+
self.module.bias,
65+
False, # bn_training
66+
0.0, # momentum
67+
self.module.eps,
68+
)
6469
return self.module(x)
6570

6671
def __repr__(self):

0 commit comments

Comments
 (0)