Skip to content

Commit cea23b2

Browse files
rusty1sJakubPietrakIntel
authored andcommitted
Weighted sampling in NeighborLoader and LinkNeighborLoader (#8038)
1 parent cfcec60 commit cea23b2

File tree

5 files changed

+128
-6
lines changed

5 files changed

+128
-6
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
77

88
### Added
99

10+
- Added support for weighted/biased sampling in `NeighborLoader`/`LinkNeighborLoader` ([#8038](https://github.com/pyg-team/pytorch_geometric/pull/8038))
1011
- Added the `MixHopConv` layer and an corresponding example ([#8025](https://github.com/pyg-team/pytorch_geometric/pull/8025))
1112
- Added the option to pass keyword arguments to the underlying normalization layers within `BasicGNN` and `MLP` ([#8024](https://github.com/pyg-team/pytorch_geometric/pull/8024), [#8033](https://github.com/pyg-team/pytorch_geometric/pull/8033))
1213
- Added `IBMBNodeLoader` and `IBMBBatchLoader` data loaders ([#6230](https://github.com/pyg-team/pytorch_geometric/pull/6230))

test/loader/test_neighbor_loader.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020
withCUDA,
2121
withPackage,
2222
)
23-
from torch_geometric.typing import WITH_PYG_LIB, WITH_TORCH_SPARSE
23+
from torch_geometric.typing import (
24+
WITH_PYG_LIB,
25+
WITH_TORCH_SPARSE,
26+
WITH_WEIGHTED_NEIGHBOR_SAMPLE,
27+
)
2428
from torch_geometric.utils import (
2529
is_undirected,
2630
sort_edge_index,
@@ -714,3 +718,67 @@ def test_neighbor_loader_mapping():
714718
batch.n_id[batch.edge_index],
715719
data.edge_index[:, batch.e_id],
716720
)
721+
722+
723+
@pytest.mark.skipif(
724+
not WITH_WEIGHTED_NEIGHBOR_SAMPLE,
725+
reason="'pyg-lib' does not support weighted neighbor sampling",
726+
)
727+
def test_weighted_homo_neighbor_loader():
728+
edge_index = torch.tensor([
729+
[1, 3, 0, 4],
730+
[2, 2, 1, 3],
731+
])
732+
edge_weight = torch.tensor([0.0, 1.0, 0.0, 1.0])
733+
734+
data = Data(num_nodes=5, edge_index=edge_index, edge_weight=edge_weight)
735+
736+
loader = NeighborLoader(
737+
data,
738+
input_nodes=torch.tensor([2]),
739+
num_neighbors=[1] * 2,
740+
batch_size=1,
741+
weight_attr='edge_weight',
742+
)
743+
assert len(loader) == 1
744+
745+
batch = next(iter(loader))
746+
747+
assert batch.num_nodes == 3
748+
assert batch.n_id.tolist() == [2, 3, 4]
749+
assert batch.num_edges == 2
750+
assert batch.n_id[batch.edge_index].tolist() == [[3, 4], [2, 3]]
751+
752+
753+
@pytest.mark.skipif(
754+
not WITH_WEIGHTED_NEIGHBOR_SAMPLE,
755+
reason="'pyg-lib' does not support weighted neighbor sampling",
756+
)
757+
def test_weighted_hetero_neighbor_loader():
758+
edge_index = torch.tensor([
759+
[1, 3, 0, 4],
760+
[2, 2, 1, 3],
761+
])
762+
edge_weight = torch.tensor([0.0, 1.0, 0.0, 1.0])
763+
764+
data = HeteroData()
765+
data['paper'].num_nodes = 5
766+
data['paper', 'to', 'paper'].edge_index = edge_index
767+
data['paper', 'to', 'paper'].edge_weight = edge_weight
768+
769+
loader = NeighborLoader(
770+
data,
771+
input_nodes=('paper', torch.tensor([2])),
772+
num_neighbors=[1] * 2,
773+
batch_size=1,
774+
weight_attr='edge_weight',
775+
)
776+
assert len(loader) == 1
777+
778+
batch = next(iter(loader))
779+
780+
assert batch['paper'].num_nodes == 3
781+
assert batch['paper'].n_id.tolist() == [2, 3, 4]
782+
assert batch['paper', 'paper'].num_edges == 2
783+
global_edge_index = batch['paper'].n_id[batch['paper', 'paper'].edge_index]
784+
assert global_edge_index.tolist() == [[3, 4], [2, 3]]

torch_geometric/loader/link_neighbor_loader.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,13 @@ class LinkNeighborLoader(LinkLoader):
165165
guaranteed to fulfill temporal constraints, *i.e.* neighbors have
166166
an earlier or equal timestamp than the center node.
167167
Only used if :obj:`edge_label_time` is set. (default: :obj:`None`)
168+
weight_attr (str, optional): The name of the attribute that denotes
169+
edge weights in the graph.
170+
If set, weighted/biased sampling will be used such that neighbors
171+
are more likely to get sampled the higher their edge weights are.
172+
Edge weights do not need to sum to one, but must be non-negative,
173+
finite and have a non-zero sum within local neighborhoods.
174+
(default: :obj:`None`)
168175
transform (callable, optional): A function/transform that takes in
169176
a sampled mini-batch and returns a transformed version.
170177
(default: :obj:`None`)
@@ -207,6 +214,7 @@ def __init__(
207214
neg_sampling: Optional[NegativeSampling] = None,
208215
neg_sampling_ratio: Optional[Union[int, float]] = None,
209216
time_attr: Optional[str] = None,
217+
weight_attr: Optional[str] = None,
210218
transform: Optional[Callable] = None,
211219
transform_sampler_output: Optional[Callable] = None,
212220
is_sorted: bool = False,
@@ -233,6 +241,7 @@ def __init__(
233241
disjoint=disjoint,
234242
temporal_strategy=temporal_strategy,
235243
time_attr=time_attr,
244+
weight_attr=weight_attr,
236245
is_sorted=is_sorted,
237246
share_memory=kwargs.get('num_workers', 0) > 0,
238247
directed=directed,

torch_geometric/loader/neighbor_loader.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,13 @@ class NeighborLoader(NodeLoader):
165165
guaranteed to fulfill temporal constraints, *i.e.* neighbors have
166166
an earlier or equal timestamp than the center node.
167167
(default: :obj:`None`)
168+
weight_attr (str, optional): The name of the attribute that denotes
169+
edge weights in the graph.
170+
If set, weighted/biased sampling will be used such that neighbors
171+
are more likely to get sampled the higher their edge weights are.
172+
Edge weights do not need to sum to one, but must be non-negative,
173+
finite and have a non-zero sum within local neighborhoods.
174+
(default: :obj:`None`)
168175
transform (callable, optional): A function/transform that takes in
169176
a sampled mini-batch and returns a transformed version.
170177
(default: :obj:`None`)
@@ -204,6 +211,7 @@ def __init__(
204211
disjoint: bool = False,
205212
temporal_strategy: str = 'uniform',
206213
time_attr: Optional[str] = None,
214+
weight_attr: Optional[str] = None,
207215
transform: Optional[Callable] = None,
208216
transform_sampler_output: Optional[Callable] = None,
209217
is_sorted: bool = False,
@@ -226,6 +234,7 @@ def __init__(
226234
disjoint=disjoint,
227235
temporal_strategy=temporal_strategy,
228236
time_attr=time_attr,
237+
weight_attr=weight_attr,
229238
is_sorted=is_sorted,
230239
share_memory=kwargs.get('num_workers', 0) > 0,
231240
directed=directed,

torch_geometric/sampler/neighbor_sampler.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def __init__(
4343
disjoint: bool = False,
4444
temporal_strategy: str = 'uniform',
4545
time_attr: Optional[str] = None,
46+
weight_attr: Optional[str] = None,
4647
is_sorted: bool = False,
4748
share_memory: bool = False,
4849
# Deprecated:
@@ -65,18 +66,30 @@ def __init__(
6566

6667
if self.data_type == DataType.homogeneous:
6768
self.num_nodes = data.num_nodes
68-
self.node_time = data[time_attr] if time_attr else None
69+
70+
self.node_time: Optional[Tensor] = None
71+
if time_attr is not None:
72+
self.node_time = data[time_attr]
6973

7074
# Convert the graph data into CSC format for sampling:
7175
self.colptr, self.row, self.perm = to_csc(
7276
data, device='cpu', share_memory=share_memory,
7377
is_sorted=is_sorted, src_node_time=self.node_time)
7478

79+
self.edge_weight: Optional[Tensor] = None
80+
if weight_attr is not None:
81+
self.edge_weight = data[weight_attr]
82+
if self.perm is not None:
83+
self.edge_weight = self.edge_weight[self.perm]
84+
7585
elif self.data_type == DataType.heterogeneous:
7686
self.node_types, self.edge_types = data.metadata()
7787

7888
self.num_nodes = {k: data[k].num_nodes for k in self.node_types}
79-
self.node_time = data.collect(time_attr) if time_attr else None
89+
90+
self.node_time: Optional[Dict[NodeType, Tensor]] = None
91+
if time_attr is not None:
92+
self.node_time = data.collect(time_attr)
8093

8194
# Conversion to/from C++ string type: Since C++ cannot take
8295
# dictionaries with tuples as key as input, edge type triplets need
@@ -91,6 +104,16 @@ def __init__(
91104
self.row_dict = remap_keys(row_dict, self.to_rel_type)
92105
self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type)
93106

107+
self.edge_weight: Optional[Dict[EdgeType, Tensor]] = None
108+
if weight_attr is not None:
109+
self.edge_weight = data.collect(weight_attr)
110+
for edge_type, edge_weight in self.edge_weight.items():
111+
if self.perm.get(edge_type, None) is not None:
112+
edge_weight = edge_weight[self.perm[edge_type]]
113+
self.edge_weight[edge_type] = edge_weight
114+
self.edge_weight = remap_keys(self.edge_weight,
115+
self.to_rel_type)
116+
94117
else: # self.data_type == DataType.remote
95118
feature_store, graph_store = data
96119

@@ -106,7 +129,7 @@ def __init__(
106129
for node_type in self.node_types
107130
}
108131

109-
self.node_time: Optional[Dict[str, Tensor]] = None
132+
self.node_time: Optional[Dict[NodeType, Tensor]] = None
110133
if time_attr is not None:
111134
# If the `time_attr` is present, we expect that `GraphStore`
112135
# holds all edges sorted by destination, and within local
@@ -136,6 +159,13 @@ def __init__(
136159
for time_attr, time_tensor in zip(time_attrs, time_tensors)
137160
}
138161

162+
self.edge_weight: Optional[Dict[EdgeType, Tensor]] = None
163+
if weight_attr is not None:
164+
raise NotImplementedError(
165+
f"'weight_attr' argument not yet supported within "
166+
f"'{self.__class__.__name__}' for "
167+
f"'(FeatureStore, GraphStore)' inputs")
168+
139169
# Conversion to/from C++ string type (see above):
140170
self.to_rel_type = {k: '__'.join(k) for k in self.edge_types}
141171
self.to_edge_type = {v: k for k, v in self.to_rel_type.items()}
@@ -145,6 +175,11 @@ def __init__(
145175
self.row_dict = remap_keys(row_dict, self.to_rel_type)
146176
self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type)
147177

178+
if (self.edge_weight is not None
179+
and not torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE):
180+
raise ImportError("Weighted neighbor sampling requires "
181+
"'pyg-lib>=0.3.0'")
182+
148183
self.num_neighbors = num_neighbors
149184
self.replace = replace
150185
self.subgraph_type = SubgraphType(subgraph_type)
@@ -233,7 +268,7 @@ def _sample(
233268
seed_time,
234269
)
235270
if torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE:
236-
args += (None, )
271+
args += (self.edge_weight, )
237272
args += (
238273
True, # csc
239274
self.replace,
@@ -313,7 +348,7 @@ def _sample(
313348
seed_time,
314349
)
315350
if torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE:
316-
args += (None, )
351+
args += (self.edge_weight, )
317352
args += (
318353
True, # csc
319354
self.replace,

0 commit comments

Comments
 (0)