
Commit 41fd354

Add PyTorch SparseTensor support for MessagePassing (#5944)
This PR adds PyTorch SparseTensor support for the base layer `MessagePassing`. There are some points to be confirmed (as marked with TODO):

+ ~~In `__collect__`: since `adj._values()` returns a detached tensor, should we use a coalesced matrix instead (e.g., `adj.coalesce().values()`)? This is for the case of computing sparse gradients of `adj`.~~ (Solved)
+ In `__collect__`: should we store the `ptr` for PyTorch SparseTensor when fused aggregation is not available?
+ In `__lift__`: should we use `gather_csr` for PyTorch SparseTensor?

Also, `torch.jit.script` is not yet available for PyTorch SparseTensor. Will figure it out soon.

Co-authored-by: rusty1s <[email protected]>
1 parent c9608f1 commit 41fd354
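The gist of the change: a `torch.sparse` COO tensor can now be passed directly to any `MessagePassing` layer, in the same *transposed* form as a `torch_sparse.SparseTensor`. A minimal sketch of what this enables (the `SumConv` layer below is illustrative, not part of the PR):

```python
import torch

from torch_geometric.nn import MessagePassing


class SumConv(MessagePassing):
    """Toy layer that sums incoming neighbor features."""
    def __init__(self):
        super().__init__(aggr='sum')

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        return x_j


x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])

# Build the adjacency as a `torch.sparse` COO tensor; as with
# `torch_sparse.SparseTensor`, it must be passed in transposed:
adj = torch.sparse_coo_tensor(edge_index, torch.ones(4), (4, 4))

conv = SumConv()
assert torch.allclose(conv(x, adj.t()), conv(x, edge_index))
```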

3 files changed: +95 -28 lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Add `to_fixed_size` graph transformer ([#5939](https://github.com/pyg-team/pytorch_geometric/pull/5939))
 - Add support for symbolic tracing of `SchNet` model ([#5938](https://github.com/pyg-team/pytorch_geometric/pull/5938))
 - Add support for customizable interaction graph in `SchNet` model ([#5919](https://github.com/pyg-team/pytorch_geometric/pull/5919))
-- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906))
+- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944))
 - Added `HydroNet` water cluster dataset ([#5537](https://github.com/pyg-team/pytorch_geometric/pull/5537), [#5902](https://github.com/pyg-team/pytorch_geometric/pull/5902), [#5903](https://github.com/pyg-team/pytorch_geometric/pull/5903))
 - Added explainability support for heterogeneous GNNs ([#5886](https://github.com/pyg-team/pytorch_geometric/pull/5886))
 - Added `SparseTensor` support to `SuperGATConv` ([#5888](https://github.com/pyg-team/pytorch_geometric/pull/5888))

test/nn/conv/test_message_passing.py

Lines changed: 30 additions & 9 deletions
@@ -7,10 +7,10 @@
 from torch.nn import Linear
 from torch_scatter import scatter
 from torch_sparse import SparseTensor
-from torch_sparse.matmul import spmm

 from torch_geometric.nn import MessagePassing, aggr
 from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
+from torch_geometric.utils import spmm


 class MyConv(MessagePassing):
@@ -55,29 +55,44 @@ def test_my_conv():
     row, col = edge_index
     value = torch.randn(row.size(0))
     adj = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))
+    torch_adj = adj.to_torch_sparse_coo_tensor()

     conv = MyConv(8, 32)
     out = conv(x1, edge_index, value)
     assert out.size() == (4, 32)
-    assert conv(x1, edge_index, value, (4, 4)).tolist() == out.tolist()
-    assert conv(x1, adj.t()).tolist() == out.tolist()
+    assert torch.allclose(conv(x1, edge_index, value, (4, 4)), out)
+    assert torch.allclose(conv(x1, adj.t()), out)
+    assert torch.allclose(conv(x1, torch_adj.t()), out)
     conv.fuse = False
-    assert conv(x1, adj.t()).tolist() == out.tolist()
+    assert torch.allclose(conv(x1, adj.t()), out)
+    assert torch.allclose(conv(x1, torch_adj.t()), out)
     conv.fuse = True

     adj = adj.sparse_resize((4, 2))
+    torch_adj = adj.to_torch_sparse_coo_tensor()
+
     conv = MyConv((8, 16), 32)
     out1 = conv((x1, x2), edge_index, value)
     out2 = conv((x1, None), edge_index, value, (4, 2))
     assert out1.size() == (2, 32)
     assert out2.size() == (2, 32)
-    assert conv((x1, x2), edge_index, value, (4, 2)).tolist() == out1.tolist()
-    assert conv((x1, x2), adj.t()).tolist() == out1.tolist()
-    assert conv((x1, None), adj.t()).tolist() == out2.tolist()
+    assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1)
+    assert torch.allclose(conv((x1, x2), adj.t()), out1)
+    assert torch.allclose(conv((x1, x2), torch_adj.t()), out1)
+    assert torch.allclose(conv((x1, None), adj.t()), out2)
+    assert torch.allclose(conv((x1, None), torch_adj.t()), out2)
     conv.fuse = False
-    assert conv((x1, x2), adj.t()).tolist() == out1.tolist()
-    assert conv((x1, None), adj.t()).tolist() == out2.tolist()
+    assert torch.allclose(conv((x1, x2), adj.t()), out1)
+    assert torch.allclose(conv((x1, x2), torch_adj.t()), out1)
+    assert torch.allclose(conv((x1, None), adj.t()), out2)
+    assert torch.allclose(conv((x1, None), torch_adj.t()), out2)
+    conv.fuse = True
+
+    # Test backward compatibility for `torch.sparse` tensors:
     conv.fuse = True
+    torch_adj = torch_adj.requires_grad_()
+    conv((x1, x2), torch_adj.t()).sum().backward()
+    assert torch_adj.grad is not None


 def test_my_conv_out_of_bounds():
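The backward test above works because `__collect__` now coalesces the sparse tensor when gradients are required (see the `message_passing.py` hunk further below): `_values()` returns a detached tensor, while `coalesce().values()` stays attached to the autograd graph. A small sketch of that behavior, as we understand it:

```python
import torch

i = torch.tensor([[0, 1], [1, 0]])
v = torch.tensor([1.0, 2.0])
adj = torch.sparse_coo_tensor(i, v, (2, 2)).requires_grad_()

# `_values()` is detached, so no gradients would flow through it:
assert not adj._values().requires_grad

# `coalesce().values()` participates in autograd:
assert adj.coalesce().values().requires_grad
```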
@@ -197,11 +212,13 @@ def test_my_multiple_aggr_conv(multi_aggr_tuple):
     edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
     row, col = edge_index
     adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
+    torch_adj = adj.to_torch_sparse_coo_tensor()

     conv = MyMultipleAggrConv(aggr_kwargs=aggr_kwargs)
     out = conv(x, edge_index)
     assert out.size() == (4, 16 * expand)
     assert torch.allclose(conv(x, adj.t()), out)
+    assert torch.allclose(conv(x, torch_adj.t()), out)


 def test_my_multiple_aggr_conv_jittable():
@@ -264,6 +281,7 @@ def test_my_edge_conv():
     edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
     row, col = edge_index
     adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
+    torch_adj = adj.to_torch_sparse_coo_tensor()

     expected = scatter(x[row] - x[col], col, dim=0, dim_size=4, reduce='add')
@@ -272,6 +290,7 @@ def test_my_edge_conv():
     assert out.size() == (4, 16)
     assert torch.allclose(out, expected)
     assert torch.allclose(conv(x, adj.t()), out)
+    assert torch.allclose(conv(x, torch_adj.t()), out)


 def test_my_edge_conv_jittable():
@@ -425,10 +444,12 @@ def test_my_default_arg_conv():
     edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
     row, col = edge_index
     adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
+    torch_adj = adj.to_torch_sparse_coo_tensor()

     conv = MyDefaultArgConv()
     assert conv(x, edge_index).view(-1).tolist() == [0, 0, 0, 0]
     assert conv(x, adj.t()).view(-1).tolist() == [0, 0, 0, 0]
+    assert conv(x, torch_adj.t()).view(-1).tolist() == [0, 0, 0, 0]


 def test_my_default_arg_conv_jittable():

torch_geometric/nn/conv/message_passing.py

Lines changed: 64 additions & 18 deletions
@@ -26,6 +26,7 @@
 from torch_geometric.nn.aggr import Aggregation, MultiAggregation
 from torch_geometric.nn.resolver import aggregation_resolver as aggr_resolver
 from torch_geometric.typing import Adj, Size
+from torch_geometric.utils import is_torch_sparse_tensor

 from .utils.helpers import expand_left
 from .utils.inspector import Inspector, func_body_repr, func_header_repr
@@ -182,7 +183,18 @@ def __init__(
     def __check_input__(self, edge_index, size):
         the_size: List[Optional[int]] = [None, None]

-        if isinstance(edge_index, Tensor):
+        if is_torch_sparse_tensor(edge_index):
+            if self.flow == 'target_to_source':
+                raise ValueError(
+                    ('Flow direction "target_to_source" is invalid for '
+                     'message propagation via `torch.sparse.Tensor`. If '
+                     'you really want to make use of a reverse message '
+                     'passing flow, pass in the transposed sparse tensor to '
+                     'the message passing module, e.g., `adj_t.t()`.'))
+            the_size[0] = edge_index.size(1)
+            the_size[1] = edge_index.size(0)
+            return the_size
+        elif isinstance(edge_index, Tensor):
             int_dtypes = (torch.uint8, torch.int8, torch.int32, torch.int64)

             if edge_index.dtype not in int_dtypes:
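To make the inferred sizes concrete: for a transposed sparse adjacency of shape `[num_dst_nodes, num_src_nodes]`, the new branch reads the source-side size from dimension 1 and the target-side size from dimension 0. A hedged sketch of the convention (the tensors below are illustrative):

```python
import torch

# 4 source nodes, 2 target nodes; adj_t has shape [dst, src] = [2, 4]:
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
adj_t = torch.sparse_coo_tensor(edge_index.flip(0), torch.ones(4), (2, 4))

# Mirrors the hunk above: size[0] = adj_t.size(1), size[1] = adj_t.size(0).
size = (adj_t.size(1), adj_t.size(0))
assert size == (4, 2)
```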
@@ -214,8 +226,8 @@ def __check_input__(self, edge_index, size):

         raise ValueError(
             ('`MessagePassing.propagate` only supports integer tensors of '
-             'shape `[2, num_messages]` or `torch_sparse.SparseTensor` for '
-             'argument `edge_index`.'))
+             'shape `[2, num_messages]`, `torch_sparse.SparseTensor` or '
+             '`torch.sparse.Tensor` for argument `edge_index`.'))

     def __set_size__(self, size: List[Optional[int]], dim: int, src: Tensor):
         the_size = size[dim]
@@ -227,7 +239,12 @@ def __set_size__(self, size: List[Optional[int]], dim: int, src: Tensor):
                 f'dimension {self.node_dim}, but expected size {the_size}.'))

     def __lift__(self, src, edge_index, dim):
-        if isinstance(edge_index, Tensor):
+        if is_torch_sparse_tensor(edge_index):
+            assert dim == 0 or dim == 1
+            index = edge_index._indices()[1 - dim]
+            return src.index_select(self.node_dim, index)
+
+        elif isinstance(edge_index, Tensor):
             try:
                 index = edge_index[dim]
                 return src.index_select(self.node_dim, index)
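In the sparse branch of `__lift__`, `_indices()[1 - dim]` flips between the two index rows: with the transposed adjacency, row 0 holds target nodes and row 1 holds source nodes. A small sketch of the lookup (our reading of the hunk above; the tensors are illustrative):

```python
import torch

x = torch.randn(4, 8)
adj_t_indices = torch.tensor([[0, 0, 1, 1],   # row 0: target nodes
                              [0, 1, 2, 3]])  # row 1: source nodes

# Lifting source features (`x_j`, dim=0) selects via row `1 - 0 = 1`:
x_j = x.index_select(0, adj_t_indices[1 - 0])
assert x_j.size() == (4, 8)  # one row of features per edge
```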
@@ -270,8 +287,8 @@ def __lift__(self, src, edge_index, dim):

         raise ValueError(
             ('`MessagePassing.propagate` only supports integer tensors of '
-             'shape `[2, num_messages]` or `torch_sparse.SparseTensor` for '
-             'argument `edge_index`.'))
+             'shape `[2, num_messages]`, `torch_sparse.SparseTensor` '
+             'or `torch.sparse.Tensor` for argument `edge_index`.'))

     def __collect__(self, args, edge_index, size, kwargs):
         i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)
@@ -296,12 +313,33 @@ def __collect__(self, args, edge_index, size, kwargs):

                 out[arg] = data

-        if isinstance(edge_index, Tensor):
+        if is_torch_sparse_tensor(edge_index):
+            if edge_index.requires_grad:
+                edge_index = edge_index.coalesce()
+                indices = edge_index.indices()
+                values = edge_index.values()
+            else:
+                indices = edge_index._indices()
+                values = edge_index._values()
+            out['adj_t'] = edge_index
+            out['edge_index'] = None
+            out['edge_index_i'] = indices[0]
+            out['edge_index_j'] = indices[1]
+            out['ptr'] = None  # TODO Get `rowptr` from CSR representation.
+            if out.get('edge_weight', None) is None:
+                out['edge_weight'] = values
+            if out.get('edge_attr', None) is None:
+                out['edge_attr'] = values
+            if out.get('edge_type', None) is None:
+                out['edge_type'] = values
+
+        elif isinstance(edge_index, Tensor):
             out['adj_t'] = None
             out['edge_index'] = edge_index
             out['edge_index_i'] = edge_index[i]
             out['edge_index_j'] = edge_index[j]
             out['ptr'] = None
+
         elif isinstance(edge_index, SparseTensor):
             out['adj_t'] = edge_index
             out['edge_index'] = None
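One consequence of this hunk: when a layer passes `edge_weight=None` (or `edge_attr`/`edge_type`) to `propagate`, the argument is filled from the sparse tensor's values. A hedged sketch (the `WeightedSumConv` layer is illustrative, not part of the PR):

```python
import torch

from torch_geometric.nn import MessagePassing


class WeightedSumConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr='sum')

    def forward(self, x, edge_index):
        # Passing `edge_weight=None` lets `__collect__` fill it in
        # from the values of a sparse adjacency:
        return self.propagate(edge_index, x=x, edge_weight=None)

    def message(self, x_j, edge_weight):
        return edge_weight.view(-1, 1) * x_j


x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
adj = torch.sparse_coo_tensor(edge_index, torch.rand(4), (4, 4))

out = WeightedSumConv()(x, adj.t())
assert out.size() == (4, 8)
```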
@@ -327,8 +365,9 @@ def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
         r"""The initial call to start propagating messages.

         Args:
-            edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor` or a
-                :obj:`torch_sparse.SparseTensor` that defines the underlying
+            edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor`, a
+                :obj:`torch_sparse.SparseTensor` or a
+                :obj:`torch.sparse.Tensor` that defines the underlying
                 graph connectivity/message passing flow.
                 :obj:`edge_index` holds the indices of a general (sparse)
                 assignment matrix of shape :obj:`[N, M]`.
@@ -338,9 +377,9 @@ def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
                 nodes in :obj:`edge_index[1]`
                 (in case :obj:`flow="source_to_target"`).
                 If :obj:`edge_index` is of type
-                :obj:`torch_sparse.SparseTensor`, its sparse indices
-                :obj:`(row, col)` should relate to :obj:`row = edge_index[1]`
-                and :obj:`col = edge_index[0]`.
+                :obj:`torch_sparse.SparseTensor` or :obj:`torch.sparse.Tensor`,
+                its sparse indices :obj:`(row, col)` should relate to
+                :obj:`row = edge_index[1]` and :obj:`col = edge_index[0]`.
                 The major difference between both formats is that we need to
                 input the *transposed* sparse adjacency matrix into
                 :func:`propagate`.
@@ -349,7 +388,8 @@ def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
                 If set to :obj:`None`, the size will be automatically inferred
                 and assumed to be quadratic.
                 This argument is ignored in case :obj:`edge_index` is a
-                :obj:`torch_sparse.SparseTensor`. (default: :obj:`None`)
+                :obj:`torch_sparse.SparseTensor` or
+                a :obj:`torch.sparse.Tensor`. (default: :obj:`None`)
             **kwargs: Any additional data which is needed to construct and
                 aggregate messages, and to update node embeddings.
         """
@@ -363,7 +403,8 @@ def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
         size = self.__check_input__(edge_index, size)

         # Run "fused" message and aggregation (if applicable).
-        if (isinstance(edge_index, SparseTensor) and self.fuse
+        if ((isinstance(edge_index, SparseTensor)
+                or is_torch_sparse_tensor(edge_index)) and self.fuse
                 and not self.explain):
             coll_dict = self.__collect__(self.__fused_user_args__, edge_index,
                                          size, kwargs)
@@ -451,8 +492,9 @@ def edge_updater(self, edge_index: Adj, **kwargs):
         graph.

         Args:
-            edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor` or a
-                :obj:`torch_sparse.SparseTensor` that defines the underlying
+            edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor`, a
+                :obj:`torch_sparse.SparseTensor` or
+                a :obj:`torch.sparse.Tensor` that defines the underlying
                 graph connectivity/message passing flow.
                 See :meth:`propagate` for more information.
             **kwargs: Any additional data which is needed to compute or update
@@ -549,13 +591,17 @@ def aggregate(self, inputs: Tensor, index: Tensor,
         return self.aggr_module(inputs, index, ptr=ptr, dim_size=dim_size,
                                 dim=self.node_dim)

-    def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor:
+    def message_and_aggregate(
+        self,
+        adj_t: Union[SparseTensor, Tensor],
+    ) -> Tensor:
         r"""Fuses computations of :func:`message` and :func:`aggregate` into a
         single function.
         If applicable, this saves both time and memory since messages do not
         explicitly need to be materialized.
         This function will only gets called in case it is implemented and
-        propagation takes place based on a :obj:`torch_sparse.SparseTensor`.
+        propagation takes place based on a :obj:`torch_sparse.SparseTensor`
+        or a :obj:`torch.sparse.Tensor`.
         """
         raise NotImplementedError
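For subclasses, the widened signature means a fused implementation can accept both sparse formats. A sketch mirroring the `MyConv.message_and_aggregate` pattern from the tests, where `torch_geometric.utils.spmm` dispatches on either format (the `FusedSumConv` class is illustrative, not part of the PR):

```python
from typing import Union

from torch import Tensor
from torch_sparse import SparseTensor

from torch_geometric.nn import MessagePassing
from torch_geometric.utils import spmm


class FusedSumConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr='sum')

    def forward(self, x: Tensor, adj_t: Union[SparseTensor, Tensor]) -> Tensor:
        return self.propagate(adj_t, x=x)

    def message_and_aggregate(
        self,
        adj_t: Union[SparseTensor, Tensor],
        x: Tensor,
    ) -> Tensor:
        # A single sparse-dense matmul replaces message() + aggregate(),
        # so per-edge messages are never materialized:
        return spmm(adj_t, x, reduce=self.aggr)
```

Since `message_and_aggregate` is implemented, `self.fuse` is enabled and the fused path is exercised for `torch.sparse` inputs just as for `torch_sparse.SparseTensor`.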
