Skip to content

Commit cd72c22

Browse files
authored
BaseStorage.get() functionality (#5240)
* update * changelog
1 parent 873cd56 commit cd72c22

File tree

4 files changed

+12
-0
lines changed

4 files changed

+12
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
55

66
## [2.2.0] - 2022-MM-DD
77
### Added
8+
- Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240))
89
- Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
910
### Changed
1011
### Removed

test/data/test_data.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ def test_data():
2222

2323
assert data.x.tolist() == x.tolist()
2424
assert data['x'].tolist() == x.tolist()
25+
assert data.get('x').tolist() == x.tolist()
26+
assert data.get('y', 2) == 2
27+
assert data.get('y', None) is None
2528

2629
assert sorted(data.keys) == ['edge_index', 'x']
2730
assert len(data) == 2

test/data/test_storage.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ def test_base_storage():
1313
assert storage.x is not None
1414
assert storage.y is not None
1515

16+
assert torch.allclose(storage.get('x', None), storage.x)
17+
assert torch.allclose(storage.get('y', None), storage.y)
18+
assert storage.get('z', 2) == 2
19+
assert storage.get('z', None) is None
20+
1621
assert len(list(storage.keys('x', 'y', 'z'))) == 2
1722
assert len(list(storage.keys('x', 'y', 'z'))) == 2
1823
assert len(list(storage.values('x', 'y', 'z'))) == 2

torch_geometric/data/storage.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,9 @@ def apply(self, func: Callable, *args: List[str]):
161161

162162
# Additional functionality ################################################
163163

164+
def get(self, key: str, value: Optional[Any] = None) -> Any:
165+
return self._mapping.get(key, value)
166+
164167
def to_dict(self) -> Dict[str, Any]:
165168
r"""Returns a dictionary of stored key/value pairs."""
166169
return copy.copy(self._mapping)

0 commit comments

Comments
 (0)