Skip to content

Commit 8076e90

Browse files
rusty1s authored and Jakub Pietrak committed
Support for input_time in NeighborLoader (pyg-team#5763)
1 parent 45d30ca commit 8076e90

14 files changed

+178
-193
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
55

66
## [2.2.0] - 2022-MM-DD
77
### Added
8+
- Added support for `input_time` in `NeighborLoader` ([#5763](https://github.com/pyg-team/pytorch_geometric/pull/5763))
89
- Added `disjoint` mode for temporal `LinkNeighborLoader` ([#5717](https://github.com/pyg-team/pytorch_geometric/pull/5717))
910
- Added `HeteroData` support for `transforms.Constant` ([#5700](https://github.com/pyg-team/pytorch_geometric/pull/5700))
1011
- Added `np.memmap` support in `NeighborLoader` ([#5696](https://github.com/pyg-team/pytorch_geometric/pull/5696))

test/loader/test_hgt_loader.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,9 @@ def test_hgt_loader():
6060
assert set(batch.node_types) == {'paper', 'author'}
6161
assert set(batch.edge_types) == set(data.edge_types)
6262

63-
assert len(batch['paper']) == 2
63+
assert len(batch['paper']) == 3
6464
assert batch['paper'].x.size() == (40, ) # 20 + 4 * 5
65+
assert batch['paper'].input_nodes.numel() == batch_size
6566
assert batch['paper'].batch_size == batch_size
6667
assert batch['paper'].x.min() >= 0 and batch['paper'].x.max() < 100
6768

test/loader/test_link_neighbor_loader.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,10 @@ def test_homogeneous_link_neighbor_loader(directed, neg_sampling_ratio):
5151
for batch in loader:
5252
assert isinstance(batch, Data)
5353

54-
assert len(batch) == 5
54+
assert len(batch) == 6
5555
assert batch.x.size(0) <= 100
5656
assert batch.x.min() >= 0 and batch.x.max() < 100
57+
assert batch.input_links.numel() == 20
5758
assert batch.edge_index.min() >= 0
5859
assert batch.edge_index.max() < batch.num_nodes
5960
assert batch.edge_attr.min() >= 0
@@ -110,7 +111,7 @@ def test_heterogeneous_link_neighbor_loader(directed, neg_sampling_ratio):
110111

111112
for batch in loader:
112113
assert isinstance(batch, HeteroData)
113-
assert len(batch) == 5
114+
assert len(batch) == 6
114115
if neg_sampling_ratio == 0.0:
115116
# Assert only positive samples are present in the original graph:
116117
assert batch['paper', 'author'].edge_label.sum() == 0
@@ -120,7 +121,6 @@ def test_heterogeneous_link_neighbor_loader(directed, neg_sampling_ratio):
120121
assert len(edge_index | edge_label_index) == len(edge_index)
121122

122123
else:
123-
124124
assert batch['paper', 'author'].edge_label_index.size(1) == 40
125125
assert torch.all(batch['paper', 'author'].edge_label[:20] == 1)
126126
assert torch.all(batch['paper', 'author'].edge_label[20:] == 0)
@@ -195,7 +195,7 @@ def test_temporal_heterogeneous_link_neighbor_loader():
195195
data['paper', 'author'].edge_index = get_edge_index(100, 200, 1000)
196196
data['author', 'paper'].edge_index = get_edge_index(200, 100, 1000)
197197

198-
with pytest.raises(ValueError, match=r"'edge_label_time' was not set.*"):
198+
with pytest.raises(ValueError, match=r"'edge_label_time' is not set"):
199199
loader = LinkNeighborLoader(
200200
data,
201201
num_neighbors=[-1] * 2,
@@ -312,7 +312,8 @@ def test_homogeneous_link_neighbor_loader_no_edges():
312312

313313
for batch in loader:
314314
assert isinstance(batch, Data)
315-
assert len(batch) == 3
315+
assert len(batch) == 4
316+
assert batch.input_links.numel() == 20
316317
assert batch.num_nodes <= 40
317318
assert batch.edge_label_index.size(1) == 20
318319
assert batch.num_nodes == batch.edge_label_index.unique().numel()
@@ -328,8 +329,9 @@ def test_heterogeneous_link_neighbor_loader_no_edges():
328329

329330
for batch in loader:
330331
assert isinstance(batch, HeteroData)
331-
assert len(batch) == 3
332+
assert len(batch) == 4
332333
assert batch['paper'].num_nodes <= 40
334+
assert batch['paper', 'paper'].input_links.numel() == 20
333335
assert batch['paper', 'paper'].edge_label_index.size(1) == 20
334336
assert batch['paper'].num_nodes == batch[
335337
'paper', 'paper'].edge_label_index.unique().numel()

test/loader/test_neighbor_loader.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,9 @@ def test_homogeneous_neighbor_loader(directed):
4848

4949
for batch in loader:
5050
assert isinstance(batch, Data)
51-
52-
assert len(batch) == 4
51+
assert len(batch) == 5
5352
assert batch.x.size(0) <= 100
54-
assert batch.batch_size == 20
53+
assert batch.input_nodes.numel() == batch.batch_size == 20
5554
assert batch.x.min() >= 0 and batch.x.max() < 100
5655
assert batch.edge_index.min() >= 0
5756
assert batch.edge_index.max() < batch.num_nodes
@@ -118,8 +117,9 @@ def test_heterogeneous_neighbor_loader(directed):
118117
# Test node type selection:
119118
assert set(batch.node_types) == {'paper', 'author'}
120119

121-
assert len(batch['paper']) == 2
120+
assert len(batch['paper']) == 3
122121
assert batch['paper'].x.size(0) <= 100
122+
assert batch['paper'].input_nodes.numel() == batch_size
123123
assert batch['paper'].batch_size == batch_size
124124
assert batch['paper'].x.min() >= 0 and batch['paper'].x.max() < 100
125125

@@ -498,7 +498,7 @@ def test_pyg_lib_heterogeneous_neighbor_loader():
498498
'author__to__paper': [-1, -1],
499499
}
500500

501-
sample = torch.ops.pyg.hetero_neighbor_sample_cpu
501+
sample = torch.ops.pyg.hetero_neighbor_sample
502502
out1 = sample(node_types, edge_types, colptr_dict, row_dict, seed_dict,
503503
num_neighbors_dict, None, None, True, False, True, False,
504504
"uniform", True)

torch_geometric/data/lightning_datamodule.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,10 @@
88
from torch_geometric.data import Data, Dataset, HeteroData
99
from torch_geometric.data.feature_store import FeatureStore
1010
from torch_geometric.data.graph_store import GraphStore
11+
from torch_geometric.loader import LinkNeighborLoader, NeighborLoader
1112
from torch_geometric.loader.dataloader import DataLoader
12-
from torch_geometric.loader.link_neighbor_loader import (
13-
LinkNeighborLoader,
14-
get_edge_label_index,
15-
)
16-
from torch_geometric.loader.neighbor_loader import (
17-
NeighborLoader,
18-
NeighborSampler,
19-
get_input_nodes,
20-
)
13+
from torch_geometric.loader.utils import get_edge_label_index, get_input_nodes
14+
from torch_geometric.sampler import NeighborSampler
2115
from torch_geometric.typing import InputEdges, InputNodes
2216

2317
try:

torch_geometric/loader/hgt_loader.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,16 +104,18 @@ def __init__(
104104
**kwargs,
105105
):
106106
node_type, _ = get_input_nodes(data, input_nodes)
107-
node_sampler = HGTSampler(
107+
108+
hgt_sampler = HGTSampler(
108109
data,
109110
num_samples=num_samples,
110111
input_type=node_type,
111112
is_sorted=is_sorted,
112113
share_memory=kwargs.get('num_workers', 0) > 0,
113114
)
115+
114116
super().__init__(
115117
data=data,
116-
node_sampler=node_sampler,
118+
node_sampler=hgt_sampler,
117119
input_nodes=input_nodes,
118120
transform=transform,
119121
filter_per_worker=filter_per_worker,

torch_geometric/loader/link_loader.py

Lines changed: 40 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Callable, Iterator, Tuple, Union
1+
from typing import Any, Callable, Iterator, List, Tuple, Union
22

33
import torch
44

@@ -7,6 +7,7 @@
77
from torch_geometric.data.graph_store import GraphStore
88
from torch_geometric.loader.base import DataLoaderIterator
99
from torch_geometric.loader.utils import (
10+
InputData,
1011
filter_custom_store,
1112
filter_data,
1213
filter_hetero_data,
@@ -89,53 +90,57 @@ def __init__(
8990
if 'collate_fn' in kwargs:
9091
del kwargs['collate_fn']
9192

92-
self.data = data
93-
94-
# Initialize sampler with keyword arguments:
95-
# NOTE sampler is an attribute of 'DataLoader', so we use link_sampler
96-
# here:
97-
self.link_sampler = link_sampler
98-
99-
# Store additional arguments:
100-
self.edge_label = edge_label
101-
self.edge_label_index = edge_label_index
102-
self.edge_label_time = edge_label_time
103-
self.transform = transform
104-
self.filter_per_worker = filter_per_worker
105-
self.neg_sampling_ratio = neg_sampling_ratio
106-
107-
# Get input type, or None for homogeneous graphs:
93+
# Get edge type (or `None` for homogeneous graphs):
10894
edge_type, edge_label_index = get_edge_label_index(
10995
data, edge_label_index)
11096
if edge_label is None:
11197
edge_label = torch.zeros(edge_label_index.size(1),
11298
device=edge_label_index.device)
113-
self.input_type = edge_type
11499

115-
super().__init__(
116-
Dataset(edge_label_index, edge_label, edge_label_time),
117-
collate_fn=self.collate_fn,
118-
**kwargs,
100+
self.data = data
101+
self.edge_type = edge_type
102+
self.link_sampler = link_sampler
103+
self.input_data = InputData(edge_label_index[0], edge_label_index[1],
104+
edge_label, edge_label_time)
105+
self.neg_sampling_ratio = neg_sampling_ratio
106+
self.transform = transform
107+
self.filter_per_worker = filter_per_worker
108+
109+
iterator = range(edge_label_index.size(1))
110+
super().__init__(iterator, collate_fn=self.collate_fn, **kwargs)
111+
112+
def collate_fn(self, index: List[int]) -> Any:
113+
r"""Samples a subgraph from a batch of input nodes."""
114+
input_data: EdgeSamplerInput = self.input_data[index]
115+
out = self.link_sampler.sample_from_edges(
116+
input_data,
117+
negative_sampling_ratio=self.neg_sampling_ratio,
119118
)
120119

120+
if self.filter_per_worker: # Execute `filter_fn` in the worker process
121+
out = self.filter_fn(out)
122+
123+
return out
124+
121125
def filter_fn(
122126
self,
123127
out: Union[SamplerOutput, HeteroSamplerOutput],
124128
) -> Union[Data, HeteroData]:
125129
r"""Joins the sampled nodes with their corresponding features,
126-
returning the resulting (Data or HeteroData) object to be used
127-
downstream."""
130+
returning the resulting :class:`~torch_geometric.data.Data` or
131+
:class:`~torch_geometric.data.HeteroData` object to be used downstream.
132+
"""
128133
if isinstance(out, SamplerOutput):
129-
edge_label_index, edge_label, edge_label_time = out.metadata
130134
data = filter_data(self.data, out.node, out.row, out.col, out.edge,
131135
self.link_sampler.edge_permutation)
136+
132137
data.batch = out.batch
133-
data.edge_label_index = edge_label_index
134-
data.edge_label = edge_label
135-
data.edge_label_time = edge_label_time
138+
data.input_links = out.metadata[0]
139+
data.edge_label_index = out.metadata[1]
140+
data.edge_label = out.metadata[2]
141+
data.edge_label_time = out.metadata[3]
136142

137143
elif isinstance(out, HeteroSamplerOutput):
138-
edge_label_index, edge_label, edge_label_time = out.metadata
139144
if isinstance(self.data, HeteroData):
140145
data = filter_hetero_data(self.data, out.node, out.row,
141146
out.col, out.edge,
@@ -144,75 +149,25 @@ def filter_fn(
144149
data = filter_custom_store(*self.data, out.node, out.row,
145150
out.col, out.edge)
146151

147-
edge_type = self.input_type
148152
for key, batch in (out.batch or {}).items():
149153
data[key].batch = batch
150-
data[edge_type].edge_label_index = edge_label_index
151-
data[edge_type].edge_label = edge_label
152-
if edge_label_time is not None:
153-
data[edge_type].edge_label_time = edge_label_time
154+
data[self.edge_type].input_links = out.metadata[0]
155+
data[self.edge_type].edge_label_index = out.metadata[1]
156+
data[self.edge_type].edge_label = out.metadata[2]
157+
data[self.edge_type].edge_label_time = out.metadata[3]
154158

155159
else:
156160
raise TypeError(f"'{self.__class__.__name__}'' found invalid "
157161
f"type: '{type(out)}'")
158162

159163
return data if self.transform is None else self.transform(data)
160164

161-
def collate_fn(self, index: EdgeSamplerInput) -> Any:
162-
r"""Samples a subgraph from a batch of input nodes."""
163-
out = self.link_sampler.sample_from_edges(
164-
index,
165-
negative_sampling_ratio=self.neg_sampling_ratio,
166-
)
167-
if self.filter_per_worker:
168-
# We execute `filter_fn` in the worker process.
169-
out = self.filter_fn(out)
170-
return out
171-
172165
def _get_iterator(self) -> Iterator:
173166
if self.filter_per_worker:
174167
return super()._get_iterator()
175-
# We execute `filter_fn` in the main process.
168+
169+
# Execute `filter_fn` in the main process:
176170
return DataLoaderIterator(super()._get_iterator(), self.filter_fn)
177171

178172
def __repr__(self) -> str:
179173
return f'{self.__class__.__name__}()'
180-
181-
182-
###############################################################################
183-
184-
185-
class Dataset(torch.utils.data.Dataset):
186-
def __init__(
187-
self,
188-
edge_label_index: torch.Tensor,
189-
edge_label: torch.Tensor,
190-
edge_label_time: OptTensor = None,
191-
):
192-
# NOTE see documentation of LinkLoader for details on these three
193-
# input parameters:
194-
self.edge_label_index = edge_label_index
195-
self.edge_label = edge_label
196-
self.edge_label_time = edge_label_time
197-
198-
def __getitem__(
199-
self,
200-
idx: int,
201-
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[
202-
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
203-
if self.edge_label_time is None:
204-
return (
205-
self.edge_label_index[0, idx],
206-
self.edge_label_index[1, idx],
207-
self.edge_label[idx],
208-
)
209-
else:
210-
return (
211-
self.edge_label_index[0, idx],
212-
self.edge_label_index[1, idx],
213-
self.edge_label[idx],
214-
self.edge_label_time[idx],
215-
)
216-
217-
def __len__(self) -> int:
218-
return self.edge_label_index.size(1)

torch_geometric/loader/link_neighbor_loader.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -166,21 +166,16 @@ def __init__(
166166
neighbor_sampler: Optional[NeighborSampler] = None,
167167
**kwargs,
168168
):
169-
# Get input type:
170-
# TODO(manan): this computation is required twice, once here and once
171-
# in LinkLoader:
169+
# TODO(manan): Avoid duplicated computation (here and in LinkLoader):
172170
edge_type, _ = get_edge_label_index(data, edge_label_index)
173171

174-
has_time_attr = time_attr is not None
175-
has_edge_label_time = edge_label_time is not None
176-
if has_edge_label_time != has_time_attr:
172+
if (edge_label_time is not None) != (time_attr is not None):
177173
raise ValueError(
178-
f"Received conflicting 'time_attr' and 'edge_label_time' "
179-
f"arguments: 'time_attr' was "
180-
f"{'set' if has_time_attr else 'not set'} and "
181-
f"'edge_label_time' was "
182-
f"{'set' if has_edge_label_time else 'not set'}. Please "
183-
f"resolve these conflicting arguments.")
174+
f"Received conflicting 'edge_label_time' and 'time_attr' "
175+
f"arguments: 'edge_label_time' is "
176+
f"{'set' if edge_label_time is not None else 'not set'} "
177+
f"while 'input_time' is "
178+
f"{'set' if time_attr is not None else 'not set'}.")
184179

185180
if neighbor_sampler is None:
186181
neighbor_sampler = NeighborSampler(

0 commit comments

Comments
 (0)