Skip to content

Commit 4f98778

Browse files
committed
Update
1 parent 1481cdc commit 4f98778

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

examples/graph_gps.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
parser = argparse.ArgumentParser()
3333
parser.add_argument(
3434
'--attn_type', default='multihead',
35-
help="Global attention type such as multihead or performer.")
35+
help="Global attention type such as 'multihead' or 'performer'.")
3636
args = parser.parse_args()
3737

3838

@@ -79,13 +79,12 @@ def forward(self, x, pe, edge_index, edge_attr, batch):
7979
return self.mlp(x)
8080

8181

82-
class RedrawProjection(torch.nn.Module):
82+
class RedrawProjection:
8383
def __init__(self, model: torch.nn.Module,
8484
redraw_interval: Optional[int] = None):
85-
super().__init__()
8685
self.model = model
8786
self.redraw_interval = redraw_interval
88-
self.register_buffer('num_last_redraw', torch.tensor(0))
87+
self.num_last_redraw = 0
8988

9089
def redraw_projections(self):
9190
if not self.model.training or self.redraw_interval is None:
@@ -97,7 +96,7 @@ def redraw_projections(self):
9796
]
9897
for fast_attention in fast_attentions:
9998
fast_attention.redraw_projection_matrix()
100-
self.num_last_redraw.zero_()
99+
self.num_last_redraw = 0
101100
return
102101
self.num_last_redraw += 1
103102

0 commit comments

Comments
 (0)