Add Performer to GPSConv #7465
Conversation
Codecov Report
@@ Coverage Diff @@
## master #7465 +/- ##
==========================================
- Coverage 91.70% 91.40% -0.30%
==========================================
Files 447 449 +2
Lines 24925 25024 +99
==========================================
+ Hits 22857 22874 +17
- Misses 2068 2150 +82
... and 17 files with indirect coverage changes
Left some initial comments. Will take a look at it again.
Also, what sort of speedups do you see using Performer? It'd be nice to benchmark this.
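A rough way to time the two attention paths (a sketch, not from this PR: it assumes the `attn_type` keyword this PR adds to `GPSConv`, and uses `SAGEConv` as an arbitrary local message-passing layer):

```python
import time

import torch
from torch_geometric.nn import GPSConv, SAGEConv

num_nodes, channels = 2000, 64
x = torch.randn(num_nodes, channels)
edge_index = torch.randint(0, num_nodes, (2, 10000))

for attn_type in ['multihead', 'performer']:
    conv = GPSConv(channels, SAGEConv(channels, channels), heads=4,
                   attn_type=attn_type)
    conv.eval()  # Disable dropout for a fair timing comparison.
    with torch.no_grad():
        conv(x, edge_index)  # Warm-up run.
        start = time.perf_counter()
        for _ in range(5):
            conv(x, edge_index)
    print(f'{attn_type}: {(time.perf_counter() - start) / 5:.4f}s per forward')
```

The gap should widen as `num_nodes` grows, since dense multi-head attention is quadratic in the number of nodes while Performer attention is linear.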
torch_geometric/nn/conv/gps_conv.py
```python
return out


class PerformerAttention(torch.nn.Module):
```
@zechengz, can PerformerAttention be used by any PyG model that needs multi-head attention? If so, can we move it to some common place like torch_geometric/nn/models or something?
It seems that moving it to torch_geometric/nn/models would cause a circular import. Any suggestions?
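One common workaround for such cycles (an illustrative sketch, not a claim about what this PR ended up doing; the `torch_geometric.nn.models.PerformerAttention` import path is hypothetical) is to defer the import to call time:

```python
def get_attention_module(attn_type: str, channels: int, heads: int):
    # The lazy import breaks the cycle: torch_geometric.nn.models is only
    # imported when an attention module is actually requested, not when
    # gps_conv.py itself is imported.
    if attn_type == 'performer':
        from torch_geometric.nn.models import PerformerAttention  # hypothetical path
        return PerformerAttention(channels, heads)  # assumed constructor signature
    raise ValueError(f"Unknown attention type '{attn_type}'")
```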
Ah, I see. Since we plan to add more attention modules, let's create a folder called torch_geometric/nn/attention and put all the attention modules there.
Ok
@zechengz I am not sure how many attention models we plan to add. But if there are potentially several out there and multiple PyG modules could use them, we should define an Attention base class and have all the attention modules we add share the same interface, similar to torch_geometric/nn/aggr. With this, the user can simply pass any attention module as an argument to the GNN modules that support attention (a sketch follows below). WDYT? cc: @rusty1s
I don't want to gate this PR with the above suggestion, though. I've left some more comments, and it looks good.
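A minimal sketch of that interface (an assumption based on the suggestion above, not code from this PR), mirroring how torch_geometric/nn/aggr exposes a shared base class:

```python
from typing import Optional

import torch
from torch import Tensor


class Attention(torch.nn.Module):
    r"""Hypothetical base class for dense attention modules operating on
    batched node features of shape [batch_size, num_nodes, channels]."""
    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        # `mask` marks which nodes are valid in each padded batch entry.
        raise NotImplementedError

    def reset_parameters(self):
        pass
```

GNN modules that support attention could then accept any `Attention` instance as a constructor argument, just as aggregation modules are passed around today.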
> Ah, I see. Since we plan to add more attention modules, let's create a folder called torch_geometric/nn/attention and put all the attention modules there.

@wsad1 Current plan is to add this
Thanks @zechengz!
Good to merge after the comments are addressed.
Description
Added examples/graph_gps.py. Not sure where I should put it.

ZINC Example
I directly replaced the multi-head attention with Performer fast attention using the ReLU kernel. The performance is similar (there is some randomness for each run and for the graph classification task). Following is the test_mae plot for each epoch.
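For reference, a sketch of how that swap might look in the ZINC example (hedged: the `attn_type`/`attn_kwargs` keywords follow this PR's diff, and the `kernel` argument of the Performer module is an assumption):

```python
import torch
from torch_geometric.nn import GINEConv, GPSConv

channels = 64
mlp = torch.nn.Sequential(
    torch.nn.Linear(channels, channels),
    torch.nn.ReLU(),
    torch.nn.Linear(channels, channels),
)
# Previously: GPSConv(channels, GINEConv(mlp), heads=4)  # dense multi-head
# With Performer fast attention and a ReLU feature-map kernel:
conv = GPSConv(channels, GINEConv(mlp), heads=4, attn_type='performer',
               attn_kwargs={'kernel': torch.nn.ReLU()})
```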