Skip to content

Commit e998348

Browse files
authored
NeighborSampler: Sort local neighborhoods according to time (#5516)
* update * update * update * update doc-string * changelog * update * update * update
1 parent 9b4b854 commit e998348

File tree

7 files changed

+106
-42
lines changed

7 files changed

+106
-42
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3131
- Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240))
3232
- Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
3333
### Changed
34+
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/pull/5516))
3435
- Fixed a bug when applying several scalers with `PNAConv` ([#5514](https://github.com/pyg-team/pytorch_geometric/issues/5514))
3536
- Allow `.` in `ParameterDict` key names ([#5494](https://github.com/pyg-team/pytorch_geometric/pull/5494))
3637
- Renamed `drop_unconnected_nodes` to `drop_unconnected_node_types` and `drop_orig_edges` to `drop_orig_edge_types` in `AddMetapaths` ([#5490](https://github.com/pyg-team/pytorch_geometric/pull/5490))

test/loader/test_link_neighbor_loader.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
from torch_geometric.data import Data, HeteroData
55
from torch_geometric.loader import LinkNeighborLoader
6-
from torch_geometric.testing import withPackage
76
from torch_geometric.testing.feature_store import MyFeatureStore
87
from torch_geometric.testing.graph_store import MyGraphStore
98

@@ -182,7 +181,6 @@ def test_link_neighbor_loader_edge_label():
182181
assert torch.all(batch.edge_label[10:] == 0)
183182

184183

185-
@withPackage('torch_sparse>=0.6.14')
186184
def test_temporal_heterogeneous_link_neighbor_loader():
187185
data = HeteroData()
188186

test/loader/test_neighbor_loader.py

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -280,14 +280,13 @@ def forward(self, x, edge_index, edge_weight):
280280
assert torch.allclose(out1, out2, atol=1e-6)
281281

282282

283-
@withPackage('torch_sparse>=0.6.14')
284283
def test_temporal_heterogeneous_neighbor_loader_on_cora(get_dataset):
285284
dataset = get_dataset(name='Cora')
286285
data = dataset[0]
287286

288287
hetero_data = HeteroData()
289288
hetero_data['paper'].x = data.x
290-
hetero_data['paper'].time = torch.arange(data.num_nodes)
289+
hetero_data['paper'].time = torch.arange(data.num_nodes, 0, -1)
291290
hetero_data['paper', 'paper'].edge_index = data.edge_index
292291

293292
loader = NeighborLoader(hetero_data, num_neighbors=[-1, -1],
@@ -381,37 +380,57 @@ def test_custom_neighbor_loader(FeatureStore, GraphStore):
381380
'author', 'to', 'paper'].edge_index.size())
382381

383382

384-
@withPackage('torch_sparse>=0.6.14')
385383
@pytest.mark.parametrize('FeatureStore', [MyFeatureStore, HeteroData])
386384
@pytest.mark.parametrize('GraphStore', [MyGraphStore, HeteroData])
387385
def test_temporal_custom_neighbor_loader_on_cora(get_dataset, FeatureStore,
388386
GraphStore):
389387
# Initialize dataset (once):
390388
dataset = get_dataset(name='Cora')
391389
data = dataset[0]
390+
data.time = torch.arange(data.num_nodes, 0, -1)
392391

393392
# Initialize feature store, graph store, and reference:
394393
feature_store = FeatureStore()
395394
graph_store = GraphStore()
396395
hetero_data = HeteroData()
397396

398-
feature_store.put_tensor(data.x, group_name='paper', attr_name='x',
399-
index=None)
397+
feature_store.put_tensor(
398+
data.x,
399+
group_name='paper',
400+
attr_name='x',
401+
index=None,
402+
)
400403
hetero_data['paper'].x = data.x
401404

402-
feature_store.put_tensor(torch.arange(data.num_nodes), group_name='paper',
403-
attr_name='time', index=None)
404-
hetero_data['paper'].time = torch.arange(data.num_nodes)
405-
406-
num_nodes = data.x.size(dim=0)
407-
graph_store.put_edge_index(edge_index=data.edge_index,
408-
edge_type=('paper', 'to', 'paper'),
409-
layout='coo', size=(num_nodes, num_nodes))
405+
feature_store.put_tensor(
406+
data.time,
407+
group_name='paper',
408+
attr_name='time',
409+
index=None,
410+
)
411+
hetero_data['paper'].time = data.time
412+
413+
# Sort according to time in local neighborhoods:
414+
row, col = data.edge_index
415+
perm = ((col * (data.num_nodes + 1)) + data.time[row]).argsort()
416+
edge_index = data.edge_index[:, perm]
417+
418+
graph_store.put_edge_index(
419+
edge_index,
420+
edge_type=('paper', 'to', 'paper'),
421+
layout='coo',
422+
is_sorted=True,
423+
size=(data.num_nodes, data.num_nodes),
424+
)
410425
hetero_data['paper', 'to', 'paper'].edge_index = data.edge_index
411426

412-
loader1 = NeighborLoader(hetero_data, num_neighbors=[-1, -1],
413-
input_nodes='paper', time_attr='time',
414-
batch_size=128)
427+
loader1 = NeighborLoader(
428+
hetero_data,
429+
num_neighbors=[-1, -1],
430+
input_nodes='paper',
431+
time_attr='time',
432+
batch_size=128,
433+
)
415434

416435
loader2 = NeighborLoader(
417436
(feature_store, graph_store),

torch_geometric/loader/link_neighbor_loader.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,11 @@ class LinkNeighborLoader(LinkLoader):
124124
a sampled mini-batch and returns a transformed version.
125125
(default: :obj:`None`)
126126
is_sorted (bool, optional): If set to :obj:`True`, assumes that
127-
:obj:`edge_index` is sorted by column. This avoids internal
128-
re-sorting of the data and can improve runtime and memory
129-
efficiency. (default: :obj:`False`)
127+
:obj:`edge_index` is sorted by column.
128+
If :obj:`time_attr` is set, additionally requires that rows are
129+
sorted according to time within individual neighborhoods.
130+
This avoids internal re-sorting of the data and can improve
131+
runtime and memory efficiency. (default: :obj:`False`)
130132
filter_per_worker (bool, optional): If set to :obj:`True`, will filter
131133
the returning data in each worker's subprocess rather than in the
132134
main process.

torch_geometric/loader/neighbor_loader.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,11 @@ class NeighborLoader(NodeLoader):
135135
a sampled mini-batch and returns a transformed version.
136136
(default: :obj:`None`)
137137
is_sorted (bool, optional): If set to :obj:`True`, assumes that
138-
:obj:`edge_index` is sorted by column. This avoids internal
139-
re-sorting of the data and can improve runtime and memory
140-
efficiency. (default: :obj:`False`)
138+
:obj:`edge_index` is sorted by column.
139+
If :obj:`time_attr` is set, additionally requires that rows are
140+
sorted according to time within individual neighborhoods.
141+
This avoids internal re-sorting of the data and can improve
142+
runtime and memory efficiency. (default: :obj:`False`)
141143
filter_per_worker (bool, optional): If set to :obj:`True`, will filter
142144
the returning data in each worker's subprocess rather than in the
143145
main process.

torch_geometric/sampler/neighbor_sampler.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from torch_geometric.data import Data, HeteroData, remote_backend_utils
66
from torch_geometric.data.feature_store import FeatureStore
7-
from torch_geometric.data.graph_store import GraphStore
7+
from torch_geometric.data.graph_store import EdgeLayout, GraphStore
88
from torch_geometric.sampler.base import (
99
BaseSampler,
1010
EdgeSamplerInput,
@@ -80,7 +80,7 @@ def __init__(
8080

8181
# Convert the graph data into a suitable format for sampling.
8282
out = to_csc(data, device='cpu', share_memory=share_memory,
83-
is_sorted=is_sorted)
83+
is_sorted=is_sorted, src_node_time=self.node_time)
8484
self.colptr, self.row, self.perm = out
8585
assert isinstance(num_neighbors, (list, tuple))
8686

@@ -99,7 +99,8 @@ def __init__(
9999

100100
# Obtain CSC representations for in-memory sampling:
101101
out = to_hetero_csc(data, device='cpu', share_memory=share_memory,
102-
is_sorted=is_sorted)
102+
is_sorted=is_sorted,
103+
node_time_dict=self.node_time_dict)
103104
colptr_dict, row_dict, perm_dict = out
104105

105106
# Conversions to/from C++ string type:
@@ -125,16 +126,34 @@ def __init__(
125126
# TODO support `FeatureStore` with no edge types (e.g. `Data`)
126127
feature_store, graph_store = data
127128

129+
# Obtain all node and edge metadata:
130+
node_attrs = feature_store.get_all_tensor_attrs()
131+
edge_attrs = graph_store.get_all_edge_attrs()
132+
128133
# TODO support `collect` on `FeatureStore`:
129134
self.node_time_dict = None
130135
if time_attr is not None:
136+
# If the `time_attr` is present, we expect that `GraphStore`
137+
# holds all edges sorted by destination, and within local
138+
# neighborhoods, node indices should be sorted by time.
139+
# TODO (matthias, manan) Find an alternative way to ensure this ordering.
140+
for edge_attr in edge_attrs:
141+
if edge_attr.layout == EdgeLayout.CSR:
142+
raise ValueError(
143+
"Temporal sampling requires that edges are stored "
144+
"in either COO or CSC layout")
145+
if not edge_attr.is_sorted:
146+
raise ValueError(
147+
"Temporal sampling requires that edges are "
148+
"sorted by destination, and by source time "
149+
"within local neighborhoods")
150+
131151
# We need to obtain all features with 'attr_name=time_attr'
132152
# from the feature store and store them in node_time_dict. To
133153
# do so, we make an explicit feature store GET call here with
134154
# the relevant 'TensorAttr's
135155
time_attrs = [
136-
attr for attr in feature_store.get_all_tensor_attrs()
137-
if attr.attr_name == time_attr
156+
attr for attr in node_attrs if attr.attr_name == time_attr
138157
]
139158
for attr in time_attrs:
140159
attr.index = None
@@ -144,10 +163,6 @@ def __init__(
144163
for time_attr, time_tensor in zip(time_attrs, time_tensors)
145164
}
146165

147-
# Obtain all node and edge metadata:
148-
node_attrs = feature_store.get_all_tensor_attrs()
149-
edge_attrs = graph_store.get_all_edge_attrs()
150-
151166
self.node_types = list(
152167
set(node_attr.group_name for node_attr in node_attrs))
153168
self.edge_types = list(

torch_geometric/sampler/utils.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,37 @@
77

88
from torch_geometric.data import Data, HeteroData
99
from torch_geometric.data.storage import EdgeStorage
10-
from torch_geometric.typing import EdgeType, OptTensor
10+
from torch_geometric.typing import EdgeType, NodeType, OptTensor
1111

1212
# Edge Layout Conversion ######################################################
1313

1414

15+
def sort_csc(
    row: Tensor,
    col: Tensor,
    src_node_time: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""Sorts COO edges by destination (column) index, as needed for CSC.

    If :obj:`src_node_time` is given, edges that share the same column are
    additionally ordered by the timestamp of their source (row) node.

    Args:
        row (Tensor): Source node index of every edge.
        col (Tensor): Destination node index of every edge.
        src_node_time (Tensor, optional): Per-node timestamps, indexed by the
            values in :obj:`row`. It is not modified. (default: :obj:`None`)

    Returns:
        The permuted :obj:`row`, the permuted :obj:`col`, and the permutation
        :obj:`perm` that was applied to the edges.
    """
    if src_node_time is None:
        col, perm = col.sort()
        return row[perm], col, perm
    else:
        # Multiplying by raw `datetime64[ns]` values may cause overflows.
        # As such, we normalize time into range [0, 1) before sorting.
        # `copy=True` guarantees the in-place ops never touch the caller's
        # tensor, even if it already is of type `torch.double`.
        src_node_time = src_node_time.to(torch.double, copy=True)
        min_time, max_time = src_node_time.min(), src_node_time.max() + 1
        # Divide by the *range* rather than by `max_time`: otherwise negative
        # timestamps yield values >= 1 (or a division by zero), which would
        # leak into the integer column part of the sort key below.
        src_node_time.sub_(min_time).div_(max_time - min_time)

        # Single sort key: integer part = column, fractional part = time.
        # NOTE: the key is a double, so time resolution within a column
        # shrinks as column indices grow.
        perm = src_node_time[row].add_(col.to(torch.double)).argsort()
        return row[perm], col[perm], perm
32+
33+
1534
# TODO(manan) deprecate when FeatureStore / GraphStore unification is complete
1635
def to_csc(
1736
data: Union[Data, EdgeStorage],
1837
device: Optional[torch.device] = None,
1938
share_memory: bool = False,
2039
is_sorted: bool = False,
40+
src_node_time: Optional[Tensor] = None,
2141
) -> Tuple[Tensor, Tensor, OptTensor]:
2242
# Convert the graph data into a suitable format for sampling (CSC format).
2343
# Returns the `colptr` and `row` indices of the graph, as well as an
@@ -27,17 +47,23 @@ def to_csc(
2747
perm: Optional[Tensor] = None
2848

2949
if hasattr(data, 'adj'):
50+
if src_node_time is not None:
51+
raise NotImplementedError("Temporal sampling via 'SparseTensor' "
52+
"format not yet supported")
3053
colptr, row, _ = data.adj.csc()
3154

3255
elif hasattr(data, 'adj_t'):
56+
if src_node_time is not None:
57+
raise NotImplementedError("Temporal sampling via 'SparseTensor' "
58+
"format not yet supported")
3359
colptr, row, _ = data.adj_t.csr()
3460

3561
elif data.edge_index is not None:
36-
(row, col) = data.edge_index
62+
row, col = data.edge_index
3763
if not is_sorted:
38-
perm = (col * data.size(0)).add_(row).argsort()
39-
row = row[perm]
40-
colptr = torch.ops.torch_sparse.ind2ptr(col[perm], data.size(1))
64+
row, col, perm = sort_csc(row, col, src_node_time)
65+
colptr = torch.ops.torch_sparse.ind2ptr(col, data.size(1))
66+
4167
else:
4268
row = torch.empty(0, dtype=torch.long, device=device)
4369
colptr = torch.zeros(data.num_nodes + 1, dtype=torch.long,
@@ -61,17 +87,18 @@ def to_hetero_csc(
6187
device: Optional[torch.device] = None,
6288
share_memory: bool = False,
6389
is_sorted: bool = False,
90+
node_time_dict: Optional[Dict[NodeType, Tensor]] = None,
6491
) -> Tuple[Dict[str, Tensor], Dict[str, Tensor], Dict[str, OptTensor]]:
6592
# Convert the heterogeneous graph data into a suitable format for sampling
6693
# (CSC format).
6794
# Returns dictionaries holding `colptr` and `row` indices as well as edge
6895
# permutations for each edge type, respectively.
6996
colptr_dict, row_dict, perm_dict = {}, {}, {}
7097

71-
for store in data.edge_stores:
72-
key = store._key
73-
out = to_csc(store, device, share_memory, is_sorted)
74-
colptr_dict[key], row_dict[key], perm_dict[key] = out
98+
for edge_type, store in data.edge_items():
99+
src_node_time = (node_time_dict or {}).get(edge_type[0], None)
100+
out = to_csc(store, device, share_memory, is_sorted, src_node_time)
101+
colptr_dict[edge_type], row_dict[edge_type], perm_dict[edge_type] = out
75102

76103
return colptr_dict, row_dict, perm_dict
77104

0 commit comments

Comments
 (0)