Add PyTorch SparseTensor support for MessagePassing #5944
Conversation
Testing failed due to numerical instability. It seems we should use …
Codecov Report
```
@@            Coverage Diff             @@
##           master    #5944      +/-   ##
==========================================
- Coverage   86.36%   84.54%    -1.82%
==========================================
  Files         361      361
  Lines       19849    19872      +23
==========================================
- Hits        17142    16801     -341
- Misses       2707     3071     +364
```
This looks pretty good. Is it ready to review?
Yes. Please take a look at your convenience :)
This is super cool. I think we can also think about adding an `is_sparse_tensor` utility to PyG that checks for either `torch.sparse.Tensor` or `SparseTensor`, to clean up some duplicated code in a follow-up.
Sure :) Will also add it to the roadmap.
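The proposed utility could look roughly like the following. This is a hypothetical sketch, not the implementation that eventually landed in PyG; the guarded import assumes the optional `torch-sparse` package may be absent:

```python
import torch

try:
    from torch_sparse import SparseTensor  # optional torch-sparse dependency
except ImportError:
    SparseTensor = None


def is_sparse_tensor(src) -> bool:
    """Return True for either a torch.sparse COO tensor or a
    torch_sparse.SparseTensor (hypothetical helper, as proposed above)."""
    if SparseTensor is not None and isinstance(src, SparseTensor):
        return True
    return isinstance(src, torch.Tensor) and src.layout == torch.sparse_coo
```

A single check like this would let `MessagePassing` branch once on sparsity instead of testing both types at every call site.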
This PR adds PyTorch SparseTensor support to the base layer `MessagePassing`. There are some points to be confirmed (as marked with TODO):

+ ~~In `__collect__`: Since `adj._values()` returns a detached tensor, should we use a coalesced matrix instead (e.g., `adj.coalesce().values()`)? This is for the case of computing sparse gradients of `adj`.~~ (Solved)
+ In `__collect__`: Should we store the `ptr` for PyTorch SparseTensor when fused aggregation is not available?
+ In `__lift__`: Should we use `gather_csr` for PyTorch SparseTensor?

Also, `torch.jit.script` is not available for PyTorch SparseTensor. Will figure it out soon.

Co-authored-by: rusty1s <[email protected]>
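The first (solved) TODO hinges on a PyTorch detail: the internal accessor `_values()` returns a detached tensor, while `coalesce().values()` stays on the autograd graph, which matters when gradients with respect to `adj` are needed. A minimal sketch of the difference:

```python
import torch

# Small sparse adjacency with differentiable values.
i = torch.tensor([[0, 1], [1, 0]])
v = torch.tensor([1.0, 2.0], requires_grad=True)
adj = torch.sparse_coo_tensor(i, v, (2, 2)).coalesce()

vals = adj.values()   # autograd-aware: gradients flow back to `v`
raw = adj._values()   # internal accessor: returns a detached tensor

print(vals.requires_grad, raw.requires_grad)
```

This is why the PR prefers the coalesced path over `_values()` when sparse gradients of the adjacency are required.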
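On the second and third TODOs: in PyG, `ptr` is the CSR row pointer consumed by `segment_csr`/`gather_csr`. For a PyTorch sparse tensor, one candidate source for it (a sketch under the assumption that a CSR conversion is acceptable) is `to_sparse_csr()` plus `crow_indices()`:

```python
import torch

# COO adjacency of a 3-node graph with edges (0,1), (1,0), (1,2).
i = torch.tensor([[0, 1, 1], [1, 0, 2]])
v = torch.ones(3)
adj = torch.sparse_coo_tensor(i, v, (3, 3)).coalesce()

csr = adj.to_sparse_csr()
ptr = csr.crow_indices()  # CSR row pointer, analogous to PyG's `ptr`
col = csr.col_indices()   # column index of each stored edge
```

Whether caching this pointer in `__collect__` pays off when fused aggregation is unavailable is exactly the open question in the list above.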