Add PyTorch SparseTensor support for MessagePassing #5944

Merged 10 commits into master on Nov 15, 2022

Conversation

@EdisonLeeeee (Contributor) commented Nov 10, 2022

This PR adds PyTorch SparseTensor support to the base `MessagePassing` layer. A few points still need to be confirmed (marked with TODO):

  • ~~In `__collect__`: since `adj._values()` returns a detached tensor, should we use a coalesced matrix instead (e.g., `adj.coalesce().values()`)? This matters when computing sparse gradients of `adj`.~~ (Solved)
  • In `__collect__`: should we store the `ptr` for PyTorch SparseTensor when fused aggregation is not available?
  • In `__lift__`: should we use `gather_csr` for PyTorch SparseTensor?

Also, `torch.jit.script` does not yet work with PyTorch SparseTensor. I will figure this out soon.
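The first (solved) point above can be illustrated with a small sketch, assuming a recent PyTorch with autograd support for `coalesce()`: `adj._values()` returns a tensor detached from the autograd graph, while `adj.coalesce().values()` keeps gradient tracking, which is what computing sparse gradients of `adj` requires.

```python
import torch

# Toy 2x2 sparse adjacency with differentiable values.
idx = torch.tensor([[0, 1], [1, 0]])
val = torch.tensor([1.0, 2.0], requires_grad=True)
adj = torch.sparse_coo_tensor(idx, val, (2, 2))

# `_values()` returns a detached tensor: no gradient flows back to `val`.
print(adj._values().requires_grad)   # False

# `coalesce().values()` stays on the autograd graph.
vals = adj.coalesce().values()
print(vals.requires_grad)            # True
vals.sum().backward()                # gradients reach the raw values
```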

@EdisonLeeeee (Contributor, Author)

Testing failed due to numerical instability; it seems we should use `torch.allclose` instead of exact equality.
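For context (assuming the failing check compared tensors exactly), `torch.allclose` compares within floating-point tolerances rather than bitwise. A minimal illustration:

```python
import torch

# Two results that differ only by floating-point round-off.
a = torch.tensor([1.0 + 1e-7, 2.0])
b = torch.tensor([1.0, 2.0])

print(torch.equal(a, b))     # False: bitwise equality is too strict
print(torch.allclose(a, b))  # True: within default rtol=1e-5, atol=1e-8
```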

@codecov bot commented Nov 10, 2022

Codecov Report

Merging #5944 (9c575b1) into master (c9608f1) will decrease coverage by 1.81%.
The diff coverage is 100.00%.

❗ Current head 9c575b1 differs from pull request most recent head 30f3ed1. Consider uploading reports for the commit 30f3ed1 to get more accurate results

@@            Coverage Diff             @@
##           master    #5944      +/-   ##
==========================================
- Coverage   86.36%   84.54%   -1.82%     
==========================================
  Files         361      361              
  Lines       19849    19872      +23     
==========================================
- Hits        17142    16801     -341     
- Misses       2707     3071     +364     
| Impacted Files | Coverage Δ |
| --- | --- |
| torch_geometric/nn/conv/message_passing.py | 98.93% <100.00%> (+0.08%) ⬆️ |
| torch_geometric/nn/models/dimenet_utils.py | 0.00% <0.00%> (-75.52%) ⬇️ |
| torch_geometric/nn/models/dimenet.py | 14.90% <0.00%> (-52.76%) ⬇️ |
| torch_geometric/profile/profile.py | 36.73% <0.00%> (-27.56%) ⬇️ |
| torch_geometric/nn/conv/utils/typing.py | 81.25% <0.00%> (-17.50%) ⬇️ |
| torch_geometric/nn/pool/asap.py | 92.10% <0.00%> (-7.90%) ⬇️ |
| torch_geometric/nn/inits.py | 67.85% <0.00%> (-7.15%) ⬇️ |
| torch_geometric/transforms/add_self_loops.py | 94.44% <0.00%> (-5.56%) ⬇️ |
| torch_geometric/nn/models/attentive_fp.py | 95.83% <0.00%> (-4.17%) ⬇️ |
| torch_geometric/nn/resolver.py | 86.36% <0.00%> (-3.41%) ⬇️ |
| ... and 12 more | |


@rusty1s (Member) commented Nov 10, 2022

This looks pretty good. Is it ready to review?

@EdisonLeeeee (Contributor, Author)

Yes. Please take a look at your convenience :)

@rusty1s (Member) left a comment

This is super cool. I think we could also add an `is_sparse_tensor` utility to PyG that checks for either `torch.sparse.Tensor` or `SparseTensor`, to clean up some duplicated code in a follow-up.
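A minimal sketch of such a utility; the name and the exact checks here are assumptions for illustration, not the API PyG eventually shipped:

```python
import torch

def is_sparse_tensor(src) -> bool:
    """Check for either a PyTorch sparse tensor or a torch_sparse.SparseTensor.

    Hypothetical sketch of the proposed utility.
    """
    if isinstance(src, torch.Tensor):
        # `layout` covers both COO and CSR PyTorch sparse tensors
        # (`Tensor.is_sparse` alone is True only for the COO layout).
        return src.layout in (torch.sparse_coo, torch.sparse_csr)
    try:  # torch_sparse is an optional dependency.
        from torch_sparse import SparseTensor
        return isinstance(src, SparseTensor)
    except ImportError:
        return False

# Usage:
print(is_sparse_tensor(torch.eye(2).to_sparse()))  # True
print(is_sparse_tensor(torch.eye(2)))              # False
```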

@rusty1s rusty1s enabled auto-merge (squash) November 15, 2022 12:19
@rusty1s rusty1s merged commit 41fd354 into master Nov 15, 2022
@rusty1s rusty1s deleted the message-passing branch November 15, 2022 12:25
@EdisonLeeeee (Contributor, Author)

Sure :) Will also add it to the roadmap.

jjpietrak pushed a commit to jjpietrak/pytorch_geometric that referenced this pull request on Nov 16, 2022, with the PR description as the commit message (co-authored by rusty1s).
jjpietrak pushed a commit to jjpietrak/pytorch_geometric that referenced this pull request on Nov 25, 2022, with the PR description as the commit message (co-authored by rusty1s).