Skip to content

Commit 6405852

Browse files
authored
Account for unsorted inputs when computing e_id in NeighborSampler (#7953)
1 parent 2febd38 commit 6405852

File tree

8 files changed

+81
-9
lines changed

8 files changed

+81
-9
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
9191

9292
### Changed
9393

94+
- Fixed a bug in which `batch.e_id` was not correctly computed on unsorted graph inputs ([#7953](https://github.com/pyg-team/pytorch_geometric/pull/7953))
9495
- Fixed `from_networkx` conversion from `nx.stochastic_block_model` graphs ([#7941](https://github.com/pyg-team/pytorch_geometric/pull/7941))
9596
- Fixed the usage of `bias_initializer` in `HeteroLinear` ([#7923](https://github.com/pyg-team/pytorch_geometric/pull/7923))
9697
- Fixed broken links in `HGBDataset` ([#7907](https://github.com/pyg-team/pytorch_geometric/pull/7907))

test/explain/algorithm/test_graphmask_explainer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def test_graph_mask_explainer_binary_classification(
123123

124124
explainer = Explainer(
125125
model=model,
126-
algorithm=GraphMaskExplainer(2, epochs=5),
126+
algorithm=GraphMaskExplainer(2, epochs=5, log=False),
127127
explanation_type=explanation_type,
128128
node_mask_type=node_mask_type,
129129
edge_mask_type=edge_mask_type,
@@ -171,7 +171,7 @@ def test_graph_mask_explainer_multiclass_classification(
171171

172172
explainer = Explainer(
173173
model=model,
174-
algorithm=GraphMaskExplainer(2, epochs=5),
174+
algorithm=GraphMaskExplainer(2, epochs=5, log=False),
175175
explanation_type=explanation_type,
176176
node_mask_type=node_mask_type,
177177
edge_mask_type=edge_mask_type,
@@ -216,7 +216,7 @@ def test_graph_mask_explainer_regression(
216216

217217
explainer = Explainer(
218218
model=model,
219-
algorithm=GraphMaskExplainer(2, epochs=5),
219+
algorithm=GraphMaskExplainer(2, epochs=5, log=False),
220220
explanation_type=explanation_type,
221221
node_mask_type=node_mask_type,
222222
edge_mask_type=edge_mask_type,

test/loader/test_link_neighbor_loader.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,3 +570,26 @@ def test_hetero_link_neighbor_loader_triplet(disjoint, temporal, amount):
570570
for i in range(batch_size):
571571
assert (node_store.time[node_store.batch == i].max()
572572
<= node_store.seed_time[i])
573+
574+
575+
@withPackage('pyg_lib')
576+
def test_link_neighbor_loader_mapping():
577+
edge_index = torch.tensor([
578+
[0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 3, 5],
579+
[1, 2, 3, 4, 5, 8, 6, 7, 9, 10, 6, 11],
580+
])
581+
data = Data(edge_index=edge_index, num_nodes=12)
582+
583+
loader = LinkNeighborLoader(
584+
data,
585+
edge_label_index=data.edge_index,
586+
num_neighbors=[1],
587+
batch_size=2,
588+
shuffle=True,
589+
)
590+
591+
for batch in loader:
592+
assert torch.equal(
593+
batch.n_id[batch.edge_index],
594+
data.edge_index[:, batch.e_id],
595+
)

test/loader/test_neighbor_loader.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,3 +692,25 @@ def test_hetero_neighbor_loader_sampled_info():
692692
for edge_type in batch.edge_types:
693693
assert (batch[edge_type].num_sampled_edges ==
694694
expected_num_sampled_edges[edge_type])
695+
696+
697+
@withPackage('pyg_lib')
698+
def test_neighbor_loader_mapping():
699+
edge_index = torch.tensor([
700+
[0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 3, 5],
701+
[1, 2, 3, 4, 5, 8, 6, 7, 9, 10, 6, 11],
702+
])
703+
data = Data(edge_index=edge_index, num_nodes=12)
704+
705+
loader = NeighborLoader(
706+
data,
707+
num_neighbors=[1],
708+
batch_size=2,
709+
shuffle=True,
710+
)
711+
712+
for batch in loader:
713+
assert torch.equal(
714+
batch.n_id[batch.edge_index],
715+
data.edge_index[:, batch.e_id],
716+
)

torch_geometric/loader/link_loader.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,9 @@ def filter_fn(
224224
if 'n_id' not in data:
225225
data.n_id = out.node
226226
if out.edge is not None and 'e_id' not in data:
227-
data.e_id = out.edge
227+
edge = out.edge.to(torch.long)
228+
perm = self.link_sampler.edge_permutation
229+
data.e_id = perm[out.edge] if perm is not None else out.edge
228230

229231
data.batch = out.batch
230232
data.num_sampled_nodes = out.num_sampled_nodes
@@ -260,8 +262,10 @@ def filter_fn(
260262
data[key].n_id = node
261263

262264
for key, edge in (out.edge or {}).items():
263-
if 'e_id' not in data[key]:
264-
data[key].e_id = edge
265+
if edge is not None and 'e_id' not in data[key]:
266+
edge = edge.to(torch.long)
267+
perm = self.link_sampler.edge_permutation[key]
268+
data[key].e_id = perm[edge] if perm is not None else edge
265269

266270
data.set_value_dict('batch', out.batch)
267271
data.set_value_dict('num_sampled_nodes', out.num_sampled_nodes)

torch_geometric/loader/link_neighbor_loader.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,14 @@ class LinkNeighborLoader(LinkLoader):
6363
The rest of the functionality mirrors that of
6464
:class:`~torch_geometric.loader.NeighborLoader`, including support for
6565
heterogeneous graphs.
66+
In particular, the data loader will add the following attributes to the
67+
returned mini-batch:
68+
69+
* :obj:`n_id` The global node index for every sampled node
70+
* :obj:`e_id` The global edge index for every sampled edge
71+
* :obj:`input_id`: The global index of the :obj:`edge_label_index`
72+
* :obj:`num_sampled_nodes`: The number of sampled nodes in each hop
73+
* :obj:`num_sampled_edges`: The number of sampled edges in each hop
6674
6775
.. note::
6876
Negative sampling is currently implemented in an approximate

torch_geometric/loader/neighbor_loader.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,16 @@ class NeighborLoader(NodeLoader):
101101
sampled_data = next(iter(loader))
102102
print(sampled_data.n_id) # Global node index of each node in batch.
103103
104+
In particular, the data loader will add the following attributes to the
105+
returned mini-batch:
106+
107+
* :obj:`batch_size` The number of seed nodes (first nodes in the batch)
108+
* :obj:`n_id` The global node index for every sampled node
109+
* :obj:`e_id` The global edge index for every sampled edge
110+
* :obj:`input_id`: The global index of the :obj:`input_nodes`
111+
* :obj:`num_sampled_nodes`: The number of sampled nodes in each hop
112+
* :obj:`num_sampled_edges`: The number of sampled edges in each hop
113+
104114
Args:
105115
data (Any): A :class:`~torch_geometric.data.Data`,
106116
:class:`~torch_geometric.data.HeteroData`, or

torch_geometric/loader/node_loader.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,9 @@ def filter_fn(
156156
if 'n_id' not in data:
157157
data.n_id = out.node
158158
if out.edge is not None and 'e_id' not in data:
159-
data.e_id = out.edge
159+
edge = out.edge.to(torch.long)
160+
perm = self.node_sampler.edge_permutation
161+
data.e_id = perm[edge] if perm is not None else edge
160162

161163
data.batch = out.batch
162164
data.num_sampled_nodes = out.num_sampled_nodes
@@ -180,8 +182,10 @@ def filter_fn(
180182
data[key].n_id = node
181183

182184
for key, edge in (out.edge or {}).items():
183-
if 'e_id' not in data[key]:
184-
data[key].e_id = edge
185+
if edge is not None and 'e_id' not in data[key]:
186+
edge = edge.to(torch.long)
187+
perm = self.node_sampler.edge_permutation[key]
188+
data[key].e_id = perm[edge] if perm is not None else edge
185189

186190
data.set_value_dict('batch', out.batch)
187191
data.set_value_dict('num_sampled_nodes', out.num_sampled_nodes)

0 commit comments

Comments
 (0)