Commit 5e75061
Support np.memmap in NeighborLoader (#5696)
1 parent 52ce0e4 commit 5e75061

5 files changed: +89 -32 lines changed
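This commit teaches `NeighborLoader` to sample from graphs whose node features live on disk as `np.memmap` arrays. A minimal usage sketch based on the new test further below; the file path, feature shape, and random edges are illustrative assumptions, not part of the commit:

```python
import numpy as np
import torch

from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader

# Write a feature matrix to disk once (100 nodes, 32 features; sizes are arbitrary).
path = '/tmp/features.npy'
x = np.memmap(path, dtype=np.float32, mode='w+', shape=(100, 32))
x[:] = np.random.randn(100, 32)

data = Data(
    # Re-open read-only; features stay memory-mapped instead of loaded into RAM.
    x=np.memmap(path, dtype=np.float32, mode='r', shape=(100, 32)),
    edge_index=torch.randint(0, 100, (2, 500)),
)

loader = NeighborLoader(data, num_neighbors=[5, 5], batch_size=20)
batch = next(iter(loader))

# Sliced mini-batch features come back as regular tensors.
assert isinstance(batch.x, torch.Tensor)
assert batch.x.shape == (batch.num_nodes, 32)
```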

CHANGELOG.md
Lines changed: 1 addition & 0 deletions

@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [2.2.0] - 2022-MM-DD
 ### Added
+- Added `np.memmap` support in `NeighborLoader` ([#5696](https://github.com/pyg-team/pytorch_geometric/pull/5696))
 - Added `assortativity` that computes degree assortativity coefficient ([#5587](https://github.com/pyg-team/pytorch_geometric/pull/5587))
 - Added `SSGConv` layer ([#5599](https://github.com/pyg-team/pytorch_geometric/pull/5599))
 - Added `shuffle_node`, `mask_feature` and `add_random_edge` augmentation methdos ([#5548](https://github.com/pyg-team/pytorch_geometric/pull/5548))

test/data/test_batch.py
Lines changed: 1 addition & 1 deletion

@@ -196,7 +196,7 @@ def test_pickling():
     assert id(batch._store._parent()) == id(batch)
     assert batch.num_nodes == 20
 
-    path = f'{random.randrange(sys.maxsize)}.pt'
+    path = os.path.join(os.sep, 'tmp', f'{random.randrange(sys.maxsize)}.pt')
     torch.save(batch, path)
     assert id(batch._store._parent()) == id(batch)
     assert batch.num_nodes == 20
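As a side note, the updated test writes its checkpoint to an absolute path under /tmp instead of the current working directory. On POSIX systems the new expression reduces to a plain '/tmp/...' path (a minimal illustration, POSIX assumed):

```python
import os

# On POSIX, os.sep is '/', so the joined path is absolute under /tmp:
assert os.path.join(os.sep, 'tmp', 'example.pt') == '/tmp/example.pt'
```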

test/loader/test_neighbor_loader.py
Lines changed: 26 additions & 0 deletions

@@ -1,3 +1,7 @@
+import os
+import random
+import sys
+
 import numpy as np
 import pytest
 import torch
@@ -516,3 +520,25 @@ def test_pyg_lib_heterogeneous_neighbor_loader():
     assert len(edge_id1_dict) == len(edge_id2_dict)
     for key in edge_id1_dict.keys():
         assert torch.equal(edge_id1_dict[key], edge_id2_dict[key])
+
+
+def test_memmap_neighbor_loader():
+    path = os.path.join(os.sep, 'tmp', f'{random.randrange(sys.maxsize)}.npy')
+    x = np.memmap(path, dtype=np.float32, mode='w+', shape=(100, 32))
+    x[:] = np.random.randn(100, 32)
+
+    data = Data()
+    data.x = np.memmap(path, dtype=np.float32, mode='r', shape=(100, 32))
+    data.edge_index = get_edge_index(100, 100, 500)
+
+    assert str(data) == 'Data(x=[100, 32], edge_index=[2, 500])'
+    assert data.num_nodes == 100
+
+    loader = NeighborLoader(data, num_neighbors=[5] * 2, batch_size=20,
+                            num_workers=6)
+    batch = next(iter(loader))
+    assert batch.num_nodes <= 100
+    assert isinstance(batch.x, torch.Tensor)
+    assert batch.x.size() == (batch.num_nodes, 32)
+
+    os.remove(path)

torch_geometric/data/storage.py
Lines changed: 22 additions & 15 deletions

@@ -15,6 +15,7 @@
     Union,
 )
 
+import numpy as np
 import torch
 from torch import Tensor
 from torch_sparse import SparseTensor, coalesce
@@ -265,8 +266,10 @@ def num_nodes(self) -> Optional[int]:
         if 'num_nodes' in self:
             return self['num_nodes']
         for key, value in self.items():
-            if isinstance(value, Tensor) and (key in N_KEYS or 'node' in key):
-                return value.size(self._parent().__cat_dim__(key, value, self))
+            if (isinstance(value, (Tensor, np.ndarray))
+                    and (key in N_KEYS or 'node' in key)):
+                cat_dim = self._parent().__cat_dim__(key, value, self)
+                return value.shape[cat_dim]
         if 'adj' in self and isinstance(self.adj, SparseTensor):
             return self.adj.size(0)
         if 'adj_t' in self and isinstance(self.adj_t, SparseTensor):
@@ -291,7 +294,9 @@ def num_nodes(self) -> Optional[int]:
 
     @property
     def num_node_features(self) -> int:
-        if 'x' in self and isinstance(self.x, (Tensor, SparseTensor)):
+        if 'x' in self and isinstance(self.x, (Tensor, np.ndarray)):
+            return 1 if self.x.ndim == 1 else self.x.shape[-1]
+        if 'x' in self and isinstance(self.x, SparseTensor):
             return 1 if self.x.dim() == 1 else self.x.size(-1)
         return 0
 
@@ -302,9 +307,9 @@ def num_features(self) -> int:
     def is_node_attr(self, key: str) -> bool:
         value = self[key]
         cat_dim = self._parent().__cat_dim__(key, value, self)
-        if not isinstance(value, Tensor):
+        if not isinstance(value, (Tensor, np.ndarray)):
            return False
-        if value.dim() == 0 or value.size(cat_dim) != self.num_nodes:
+        if value.ndim == 0 or value.shape[cat_dim] != self.num_nodes:
            return False
        return True
 
@@ -350,17 +355,19 @@ def edge_index(self) -> Tensor:
     def num_edges(self) -> int:
         # We sequentially access attributes that reveal the number of edges.
         for key, value in self.items():
-            if isinstance(value, Tensor) and 'edge' in key:
-                return value.size(self._parent().__cat_dim__(key, value, self))
+            if isinstance(value, (Tensor, np.ndarray)) and 'edge' in key:
+                cat_dim = self._parent().__cat_dim__(key, value, self)
+                return value.shape[cat_dim]
         for value in self.values('adj', 'adj_t'):
             if isinstance(value, SparseTensor):
                 return value.nnz()
         return 0
 
     @property
     def num_edge_features(self) -> int:
-        if 'edge_attr' in self and isinstance(self.edge_attr, Tensor):
-            return 1 if self.edge_attr.dim() == 1 else self.edge_attr.size(-1)
+        if ('edge_attr' in self and isinstance(self.edge_attr,
+                                               (Tensor, np.ndarray))):
+            return 1 if self.edge_attr.ndim == 1 else self.edge_attr.shape[-1]
         return 0
 
     @property
@@ -386,9 +393,9 @@ def is_node_attr(self, key: str) -> bool:
     def is_edge_attr(self, key: str) -> bool:
         value = self[key]
         cat_dim = self._parent().__cat_dim__(key, value, self)
-        if not isinstance(value, Tensor):
+        if not isinstance(value, (Tensor, np.ndarray)):
            return False
-        if value.dim() == 0 or value.size(cat_dim) != self.num_edges:
+        if value.ndim == 0 or value.shape[cat_dim] != self.num_edges:
            return False
        return True
 
@@ -467,9 +474,9 @@ def is_node_attr(self, key: str) -> bool:
         cat_dim = self._parent().__cat_dim__(key, value, self)
 
         num_nodes, num_edges = self.num_nodes, self.num_edges
-        if not isinstance(value, Tensor):
+        if not isinstance(value, (Tensor, np.ndarray)):
            return False
-        if value.dim() == 0 or value.size(cat_dim) != num_nodes:
+        if value.ndim == 0 or value.shape[cat_dim] != num_nodes:
            return False
        if num_nodes != num_edges:
            return True
@@ -480,9 +487,9 @@ def is_edge_attr(self, key: str) -> bool:
         cat_dim = self._parent().__cat_dim__(key, value, self)
 
         num_nodes, num_edges = self.num_nodes, self.num_edges
-        if not isinstance(value, Tensor):
+        if not isinstance(value, (Tensor, np.ndarray)):
            return False
-        if value.dim() == 0 or value.size(cat_dim) != num_edges:
+        if value.ndim == 0 or value.shape[cat_dim] != num_edges:
            return False
        if num_nodes != num_edges:
            return True

torch_geometric/loader/utils.py
Lines changed: 39 additions & 16 deletions

@@ -3,6 +3,7 @@
 from collections.abc import Sequence
 from typing import Dict, Optional, Tuple, Union
 
+import numpy as np
 import torch
 from torch import Tensor
 from torch_sparse import SparseTensor
@@ -11,20 +12,34 @@
 from torch_geometric.data.feature_store import FeatureStore, TensorAttr
 from torch_geometric.data.graph_store import GraphStore
 from torch_geometric.data.storage import EdgeStorage, NodeStorage
-from torch_geometric.typing import InputEdges, InputNodes, OptTensor
+from torch_geometric.typing import (
+    FeatureTensorType,
+    InputEdges,
+    InputNodes,
+    OptTensor,
+)
 
 
-def index_select(value: Tensor, index: Tensor, dim: int = 0) -> Tensor:
-    out: Optional[Tensor] = None
-    if torch.utils.data.get_worker_info() is not None:
-        # If we are in a background process, we write directly into a shared
-        # memory tensor to avoid an extra copy:
-        size = list(value.size())
-        size[dim] = index.numel()
-        numel = math.prod(size)
-        storage = value.storage()._new_shared(numel)
-        out = value.new(storage).view(size)
-    return torch.index_select(value, dim, index, out=out)
+def index_select(value: FeatureTensorType, index: Tensor,
+                 dim: int = 0) -> Tensor:
+    if isinstance(value, Tensor):
+        out: Optional[Tensor] = None
+        if torch.utils.data.get_worker_info() is not None:
+            # If we are in a background process, we write directly into a
+            # shared memory tensor to avoid an extra copy:
+            size = list(value.shape)
+            size[dim] = index.numel()
+            numel = math.prod(size)
+            storage = value.storage()._new_shared(numel)
+            out = value.new(storage).view(size)
+
+        return torch.index_select(value, dim, index, out=out)
+
+    elif isinstance(value, np.ndarray):
+        return torch.from_numpy(np.take(value, index, axis=dim))
+
+    raise ValueError(f"Encountered invalid feature tensor type "
+                     f"(got '{type(value)}')")
 
 
 def filter_node_store_(store: NodeStorage, out_store: NodeStorage,
@@ -35,7 +50,10 @@ def filter_node_store_(store: NodeStorage, out_store: NodeStorage,
             out_store.num_nodes = index.numel()
 
         elif store.is_node_attr(key):
-            index = index.to(value.device)
+            if isinstance(value, Tensor):
+                index = index.to(value.device)
+            elif isinstance(value, np.ndarray):
+                index = index.cpu()
             dim = store._parent().__cat_dim__(key, value, store)
             out_store[key] = index_select(value, index, dim=dim)
 
@@ -69,12 +87,17 @@ def filter_edge_store_(store: EdgeStorage, out_store: EdgeStorage, row: Tensor,
 
         elif store.is_edge_attr(key):
             dim = store._parent().__cat_dim__(key, value, store)
-            if perm is None:
+            if isinstance(value, Tensor):
                 index = index.to(value.device)
+            elif isinstance(value, np.ndarray):
+                index = index.cpu()
+            if perm is None:
                 out_store[key] = index_select(value, index, dim=dim)
             else:
-                perm = perm.to(value.device)
-                index = index.to(value.device)
+                if isinstance(value, Tensor):
+                    perm = perm.to(value.device)
+                elif isinstance(value, np.ndarray):
+                    perm = perm.cpu()
                 out_store[key] = index_select(value, perm[index], dim=dim)
 
     return store
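The numpy branch of the reworked `index_select` gathers the requested rows with `np.take` and wraps the result via `torch.from_numpy`, so only the selected slice of an on-disk array is materialized in memory. A small sketch of that behavior; the toy array below stands in for a feature matrix and could equally be an `np.memmap`:

```python
import numpy as np
import torch

value = np.arange(12, dtype=np.float32).reshape(4, 3)  # 4 "nodes", 3 features
index = torch.tensor([2, 0])                           # sampled node ids

# np.take copies only the requested rows into a fresh in-memory array;
# torch.from_numpy then wraps that array without another copy.
out = torch.from_numpy(np.take(value, index, axis=0))

assert torch.equal(out, torch.tensor([[6., 7., 8.],
                                      [0., 1., 2.]]))
```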
