Skip to content

Commit 0438d3a

Browse files
kaixuanliurusty1sEdisonLeeeee
authored
Filter out empty tensors inside trim_to_layer (#7942)
filter out empty tensor for `x`, `edge_index`, `edge_attr` after calling `trim_to_layer` function. This can avoid unnecessary computation when some node/edge types get empty output. For example: when I train `igbh-tiny` dataset with 3 hops sampler and use `trim_to_layer` function, I get a lot of empty edge_index tensor for edge type '('author', 'affiliated_to', 'institute')', but the feature tensor for 'author' node type is still sent to compute in `HeteroConv` implementation. --------- Signed-off-by: Liu,Kaixuan <[email protected]> Co-authored-by: Matthias Fey <[email protected]> Co-authored-by: Jintang Li <[email protected]>
1 parent dea1577 commit 0438d3a

File tree

3 files changed

+56
-1
lines changed

3 files changed

+56
-1
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
9696

9797
### Changed
9898

99+
- Changed the `trim_to_layer` function to filter out non-reachable node and edge types when operating on heterogeneous graphs ([#7942](https://github.com/pyg-team/pytorch_geometric/pull/7942))
99100
- Accelerated and simplified `top_k` computation in `TopKPooling` ([#7737](https://github.com/pyg-team/pytorch_geometric/pull/7737))
100101
- Updated `GIN` implementation in kernel benchmarks to have sequential batchnorms ([#7955](https://github.com/pyg-team/pytorch_geometric/pull/7955))
101102
- Fixed bugs in benchmarks caused by a lack of the device conditions for CPU and unexpected `cache` argument in heterogeneous models ([#7956](https://github.com/pyg-team/pytorch_geometric/pull/7956)

test/utils/test_trim_to_layer.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,3 +197,39 @@ def test_trim_to_layer_with_neighbor_loader():
197197
assert out2.size() == (2, 16)
198198

199199
assert torch.allclose(out1, out2)
200+
201+
202+
def test_trim_to_layer_filtering():
203+
x_dict = {
204+
'paper': torch.rand((13, 128)),
205+
'author': torch.rand((5, 128)),
206+
'field_of_study': torch.rand((6, 128))
207+
}
208+
edge_index_dict = {
209+
('author', 'writes', 'paper'):
210+
torch.tensor([[0, 1, 2, 3, 4], [0, 0, 1, 2, 2]]),
211+
('paper', 'has_topic', 'field_of_study'):
212+
torch.tensor([[6, 7, 8, 9], [0, 0, 1, 1]])
213+
}
214+
num_sampled_nodes_dict = {
215+
'paper': [1, 2, 10],
216+
'author': [0, 2, 3],
217+
'field_of_study': [0, 2, 4]
218+
}
219+
num_sampled_edges_dict = {
220+
('author', 'writes', 'paper'): [2, 3],
221+
('paper', 'has_topic', 'field_of_study'): [0, 4]
222+
}
223+
x_dict, edge_index_dict, _ = trim_to_layer(
224+
layer=1,
225+
num_sampled_nodes_per_hop=num_sampled_nodes_dict,
226+
num_sampled_edges_per_hop=num_sampled_edges_dict,
227+
x=x_dict,
228+
edge_index=edge_index_dict,
229+
)
230+
assert list(edge_index_dict.keys()) == [('author', 'writes', 'paper')]
231+
assert torch.equal(edge_index_dict[('author', 'writes', 'paper')],
232+
torch.tensor([[0, 1], [0, 0]]))
233+
assert x_dict['paper'].size() == (3, 128)
234+
assert x_dict['author'].size() == (2, 128)
235+
assert x_dict['field_of_study'].size() == (2, 128)

torch_geometric/utils/trim_to_layer.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Dict, List, Optional, Tuple, Union
1+
import copy
2+
from typing import Any, Dict, List, Optional, Tuple, Union
23

34
import torch
45
from torch import Tensor
@@ -14,6 +15,17 @@
1415
)
1516

1617

18+
def filter_empty_entries(
19+
input_dict: Dict[Union[Any], Tensor]) -> Dict[Any, Tensor]:
20+
r"""Removes empty tensors from a dictionary. This avoids unnecessary
21+
computation when some node/edge types are non-reachable after trimming."""
22+
out_dict = copy.copy(input_dict)
23+
for key, value in input_dict.items():
24+
if value.numel() == 0:
25+
del out_dict[key]
26+
return out_dict
27+
28+
1729
def trim_to_layer(
1830
layer: int,
1931
num_sampled_nodes_per_hop: Union[List[int], Dict[NodeType, List[int]]],
@@ -53,6 +65,8 @@ def trim_to_layer(
5365
k: trim_feat(v, layer, num_sampled_nodes_per_hop[k])
5466
for k, v in x.items()
5567
}
68+
x = filter_empty_entries(x)
69+
5670
edge_index = {
5771
k:
5872
trim_adj(
@@ -64,11 +78,15 @@ def trim_to_layer(
6478
)
6579
for k, v in edge_index.items()
6680
}
81+
edge_index = filter_empty_entries(edge_index)
82+
6783
if edge_attr is not None:
6884
edge_attr = {
6985
k: trim_feat(v, layer, num_sampled_edges_per_hop[k])
7086
for k, v in edge_attr.items()
7187
}
88+
edge_attr = filter_empty_entries(edge_attr)
89+
7290
return x, edge_index, edge_attr
7391

7492
x = trim_feat(x, layer, num_sampled_nodes_per_hop)

0 commit comments

Comments
 (0)