Skip to content

Commit fc32568

Browse files
authored
Test __inc__ with nested tensors (#7647)
1 parent b60fa62 commit fc32568

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
77

88
### Added
99

10-
- Added `torch.nested_tensor` support in `Data` and `Batch` ([#7643](https://github.com/pyg-team/pytorch_geometric/pull/7643))
10+
- Added `torch.nested_tensor` support in `Data` and `Batch` ([#7643](https://github.com/pyg-team/pytorch_geometric/pull/7643), [#7647](https://github.com/pyg-team/pytorch_geometric/pull/7647))
1111
- Added `interval` argument to `Cartesian` and `LocalCartesian` transformations ([#7533](https://github.com/pyg-team/pytorch_geometric/pull/7533), [#7614](https://github.com/pyg-team/pytorch_geometric/pull/7614))
1212
- Added a `LightGCN` example on the `AmazonBook` dataset ([7603](https://github.com/pyg-team/pytorch_geometric/pull/7603))
1313
- Added a tutorial on hierarchical neighborhood sampling ([#7594](https://github.com/pyg-team/pytorch_geometric/pull/7594))

test/data/test_batch.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -521,16 +521,23 @@ def test_torch_sparse_batch(layout):
521521

522522

523523
def test_torch_nested_batch():
524+
class MyData(Data):
525+
def __inc__(self, key, value, *args, **kwargs) -> int:
526+
return 2
527+
524528
x1 = nested_tensor([torch.randn(3), torch.randn(4)])
525-
data1 = Data(x=x1)
526-
assert str(data1) == 'Data(x=[2, 4])'
529+
data1 = MyData(x=x1)
530+
assert str(data1) == 'MyData(x=[2, 4])'
527531

528532
x2 = nested_tensor([torch.randn(3), torch.randn(4), torch.randn(5)])
529-
data2 = Data(x=x2)
530-
assert str(data2) == 'Data(x=[3, 5])'
533+
data2 = MyData(x=x2)
534+
assert str(data2) == 'MyData(x=[3, 5])'
531535

532536
batch = Batch.from_data_list([data1, data2])
533-
assert str(batch) == 'DataBatch(x=[5, 5], batch=[5], ptr=[3])'
537+
assert str(batch) == 'MyDataBatch(x=[5, 5], batch=[5], ptr=[3])'
534538

535-
x = nested_tensor(list(x1.unbind() + x2.unbind())).to_padded_tensor(0.0)
536-
assert torch.equal(batch.x.to_padded_tensor(0.0), x)
539+
expected = nested_tensor(list(x1.unbind() + (x2 + 2).unbind()))
540+
assert torch.equal(
541+
batch.x.to_padded_tensor(0.0),
542+
expected.to_padded_tensor(0.0),
543+
)

0 commit comments

Comments
 (0)