Skip to content

Commit d22ae19

Browse files
committed
update
1 parent 8874b73 commit d22ae19

File tree

9 files changed

+40
-29
lines changed

9 files changed

+40
-29
lines changed

test/loader/test_hgt_loader.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,9 @@ def test_hgt_loader():
6060
assert set(batch.node_types) == {'paper', 'author'}
6161
assert set(batch.edge_types) == set(data.edge_types)
6262

63-
assert len(batch['paper']) == 2
63+
assert len(batch['paper']) == 3
6464
assert batch['paper'].x.size() == (40, ) # 20 + 4 * 5
65+
assert batch['paper'].input_nodes.numel() == batch_size
6566
assert batch['paper'].batch_size == batch_size
6667
assert batch['paper'].x.min() >= 0 and batch['paper'].x.max() < 100
6768

test/loader/test_link_neighbor_loader.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,10 @@ def test_homogeneous_link_neighbor_loader(directed, neg_sampling_ratio):
5151
for batch in loader:
5252
assert isinstance(batch, Data)
5353

54-
assert len(batch) == 5
54+
assert len(batch) == 6
5555
assert batch.x.size(0) <= 100
5656
assert batch.x.min() >= 0 and batch.x.max() < 100
57+
assert batch.input_links.numel() == 20
5758
assert batch.edge_index.min() >= 0
5859
assert batch.edge_index.max() < batch.num_nodes
5960
assert batch.edge_attr.min() >= 0
@@ -110,7 +111,7 @@ def test_heterogeneous_link_neighbor_loader(directed, neg_sampling_ratio):
110111

111112
for batch in loader:
112113
assert isinstance(batch, HeteroData)
113-
assert len(batch) == 5
114+
assert len(batch) == 6
114115
if neg_sampling_ratio == 0.0:
115116
# Assert only positive samples are present in the original graph:
116117
assert batch['paper', 'author'].edge_label.sum() == 0
@@ -120,7 +121,6 @@ def test_heterogeneous_link_neighbor_loader(directed, neg_sampling_ratio):
120121
assert len(edge_index | edge_label_index) == len(edge_index)
121122

122123
else:
123-
124124
assert batch['paper', 'author'].edge_label_index.size(1) == 40
125125
assert torch.all(batch['paper', 'author'].edge_label[:20] == 1)
126126
assert torch.all(batch['paper', 'author'].edge_label[20:] == 0)
@@ -312,7 +312,8 @@ def test_homogeneous_link_neighbor_loader_no_edges():
312312

313313
for batch in loader:
314314
assert isinstance(batch, Data)
315-
assert len(batch) == 3
315+
assert len(batch) == 4
316+
assert batch.input_links.numel() == 20
316317
assert batch.num_nodes <= 40
317318
assert batch.edge_label_index.size(1) == 20
318319
assert batch.num_nodes == batch.edge_label_index.unique().numel()
@@ -328,8 +329,9 @@ def test_heterogeneous_link_neighbor_loader_no_edges():
328329

329330
for batch in loader:
330331
assert isinstance(batch, HeteroData)
331-
assert len(batch) == 3
332+
assert len(batch) == 4
332333
assert batch['paper'].num_nodes <= 40
334+
assert batch['paper', 'paper'].input_links.numel() == 20
333335
assert batch['paper', 'paper'].edge_label_index.size(1) == 20
334336
assert batch['paper'].num_nodes == batch[
335337
'paper', 'paper'].edge_label_index.unique().numel()

test/loader/test_neighbor_loader.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,9 @@ def test_homogeneous_neighbor_loader(directed):
4848

4949
for batch in loader:
5050
assert isinstance(batch, Data)
51-
52-
assert len(batch) == 4
51+
assert len(batch) == 5
5352
assert batch.x.size(0) <= 100
54-
assert batch.batch_size == 20
53+
assert batch.input_nodes.numel() == batch.batch_size == 20
5554
assert batch.x.min() >= 0 and batch.x.max() < 100
5655
assert batch.edge_index.min() >= 0
5756
assert batch.edge_index.max() < batch.num_nodes
@@ -118,8 +117,9 @@ def test_heterogeneous_neighbor_loader(directed):
118117
# Test node type selection:
119118
assert set(batch.node_types) == {'paper', 'author'}
120119

121-
assert len(batch['paper']) == 2
120+
assert len(batch['paper']) == 3
122121
assert batch['paper'].x.size(0) <= 100
122+
assert batch['paper'].input_nodes.numel() == batch_size
123123
assert batch['paper'].batch_size == batch_size
124124
assert batch['paper'].x.min() >= 0 and batch['paper'].x.max() < 100
125125

torch_geometric/loader/link_loader.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,10 @@ def filter_fn(
135135
self.link_sampler.edge_permutation)
136136

137137
data.batch = out.batch
138-
data.edge_label_index = out.metadata[0]
139-
data.edge_label = out.metadata[1]
140-
data.edge_label_time = out.metadata[2]
138+
data.input_links = out.metadata[0]
139+
data.edge_label_index = out.metadata[1]
140+
data.edge_label = out.metadata[2]
141+
data.edge_label_time = out.metadata[3]
141142

142143
elif isinstance(out, HeteroSamplerOutput):
143144
if isinstance(self.data, HeteroData):
@@ -150,9 +151,10 @@ def filter_fn(
150151

151152
for key, batch in (out.batch or {}).items():
152153
data[key].batch = batch
153-
data[self.edge_type].edge_label_index = out.metadata[0]
154-
data[self.edge_type].edge_label = out.metadata[1]
155-
data[self.edge_type].edge_label_time = out.metadata[2]
154+
data[self.edge_type].input_links = out.metadata[0]
155+
data[self.edge_type].edge_label_index = out.metadata[1]
156+
data[self.edge_type].edge_label = out.metadata[2]
157+
data[self.edge_type].edge_label_time = out.metadata[3]
156158

157159
else:
158160
raise TypeError(f"'{self.__class__.__name__}'' found invalid "

torch_geometric/loader/node_loader.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ def filter_fn(
116116
data = filter_data(self.data, out.node, out.row, out.col, out.edge,
117117
self.node_sampler.edge_permutation)
118118
data.batch = out.batch
119-
data.batch_size = out.metadata
119+
data.input_nodes = out.metadata
120+
data.batch_size = out.metadata.size(0)
120121

121122
elif isinstance(out, HeteroSamplerOutput):
122123
if isinstance(self.data, HeteroData):
@@ -129,7 +130,8 @@ def filter_fn(
129130

130131
for key, batch in (out.batch or {}).items():
131132
data[key].batch = batch
132-
data[self.node_type].batch_size = out.metadata
133+
data[self.node_type].input_nodes = out.metadata
134+
data[self.node_type].batch_size = out.metadata.size(0)
133135

134136
else:
135137
raise TypeError(f"'{self.__class__.__name__}'' found invalid "

torch_geometric/loader/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __getitem__(self, index: Union[Tensor, List[int]]) -> Any:
2828
if not isinstance(index, Tensor):
2929
index = torch.tensor(index, dtype=torch.long)
3030

31-
outs = []
31+
outs = [index]
3232
for arg in self.args:
3333
outs.append(arg[index] if arg is not None else None)
3434
return tuple(outs)

torch_geometric/sampler/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,18 @@
77
from torch_geometric.typing import EdgeType, NodeType, OptTensor
88

99
# An input to a node-based sampler consists of two tensors:
10+
# * The example indices
1011
# * The node indices
1112
# * The timestamps of the given node indices (optional)
12-
NodeSamplerInput = Tuple[Tensor, OptTensor]
13+
NodeSamplerInput = Tuple[Tensor, Tensor, OptTensor]
1314

1415
# An input to an edge-based sampler consists of four tensors:
16+
# * The example indices
1517
# * The row of the edge index in COO format
1618
# * The column of the edge index in COO format
1719
# * The labels of the edges
1820
# * The time attribute corresponding to the edge label (optional)
19-
EdgeSamplerInput = Tuple[Tensor, Tensor, Tensor, OptTensor]
21+
EdgeSamplerInput = Tuple[Tensor, Tensor, Tensor, Tensor, OptTensor]
2022

2123

2224
# A sampler output contains the following information.

torch_geometric/sampler/hgt_sampler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def sample_from_nodes(
6060
index: NodeSamplerInput,
6161
**kwargs,
6262
) -> HeteroSamplerOutput:
63-
input_nodes, _ = index
63+
index, input_nodes, _ = index
6464
input_node_dict = {self.input_type: input_nodes}
6565
sample_fn = torch.ops.torch_sparse.hgt_sample
6666
out = sample_fn(
@@ -77,7 +77,7 @@ def sample_from_nodes(
7777
col=remap_keys(col, self.to_edge_type),
7878
edge=remap_keys(edge, self.to_edge_type),
7979
batch=batch,
80-
metadata=input_nodes.size(0),
80+
metadata=index,
8181
)
8282

8383
def sample_from_edges(

torch_geometric/sampler/neighbor_sampler.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -326,19 +326,19 @@ def sample_from_nodes(
326326
index: NodeSamplerInput,
327327
**kwargs,
328328
) -> Union[SamplerOutput, HeteroSamplerOutput]:
329-
input_nodes, input_time = index
329+
index, input_nodes, input_time = index
330330

331331
if self.data_cls == 'custom' or issubclass(self.data_cls, HeteroData):
332332
seed_time_dict = None
333333
if input_time is not None:
334334
seed_time_dict = {self.input_type: input_time}
335335
output = self._sample(seed={self.input_type: input_nodes},
336336
seed_time_dict=seed_time_dict)
337-
output.metadata = input_nodes.numel()
337+
output.metadata = index
338338

339339
elif issubclass(self.data_cls, Data):
340340
output = self._sample(seed=input_nodes, seed_time=input_time)
341-
output.metadata = input_nodes.numel()
341+
output.metadata = index
342342

343343
else:
344344
raise TypeError(f"'{self.__class__.__name__}'' found invalid "
@@ -353,7 +353,7 @@ def sample_from_edges(
353353
index: EdgeSamplerInput,
354354
**kwargs,
355355
) -> Union[SamplerOutput, HeteroSamplerOutput]:
356-
row, col, edge_label, edge_label_time = index
356+
index, row, col, edge_label, edge_label_time = index
357357
edge_label_index = torch.stack([row, col], dim=0)
358358
negative_sampling_ratio = kwargs.get('negative_sampling_ratio', 0.0)
359359

@@ -421,7 +421,8 @@ def sample_from_edges(
421421
for key, batch in output.batch.items():
422422
output.batch[key] = batch % num_seed_edges
423423

424-
output.metadata = (edge_label_index, edge_label, edge_label_time)
424+
output.metadata = (index, edge_label_index, edge_label,
425+
edge_label_time)
425426

426427
elif issubclass(self.data_cls, Data):
427428
if self.disjoint_sampling:
@@ -441,7 +442,8 @@ def sample_from_edges(
441442
if self.disjoint_sampling:
442443
output.batch = output.batch % num_seed_edges
443444

444-
output.metadata = (edge_label_index, edge_label, edge_label_time)
445+
output.metadata = (index, edge_label_index, edge_label,
446+
edge_label_time)
445447

446448
else:
447449
raise TypeError(f"'{self.__class__.__name__}'' found invalid "

0 commit comments

Comments
 (0)