Skip to content

Commit 3309f7d

Browse files
committed: "update"
1 parent: e909b52 · commit: 3309f7d

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

torch_geometric/nn/aggr/basic.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,12 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None,
101101
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
102102
dim: int = -2) -> Tensor:
103103
mean = self.reduce(x, index, ptr, dim_size, dim, reduce='mean')
104-
x2 = x.detach() * x.detach() if self.semi_grad else x * x
105-
mean_2 = self.reduce(x2, index, ptr, dim_size, dim, 'mean')
106-
return mean_2 - mean * mean
104+
if self.semi_grad:
105+
with torch.no_grad():
106+
mean2 = self.reduce(x * x, index, ptr, dim_size, dim, 'mean')
107+
else:
108+
mean2 = self.reduce(x * x, index, ptr, dim_size, dim, 'mean')
109+
return mean2 - mean * mean
107110

108111

109112
class StdAggregation(Aggregation):
@@ -200,9 +203,10 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None,
200203
alpha = x * t
201204

202205
if not self.learn and self.semi_grad:
203-
alpha = alpha.detach()
204-
alpha = softmax(alpha, index, ptr, dim_size, dim)
205-
alpha = softmax(alpha, index, ptr, dim_size, dim)
206+
with torch.no_grad():
207+
alpha = softmax(alpha, index, ptr, dim_size, dim)
208+
else:
209+
alpha = softmax(alpha, index, ptr, dim_size, dim)
206210
return self.reduce(x * alpha, index, ptr, dim_size, dim, reduce='sum')
207211

208212
def __repr__(self) -> str:

torch_geometric/nn/aggr/fused.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,11 +251,17 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None,
251251
# `include_self=True` + manual masking leads to faster runtime:
252252
out = x.new_full((dim_size, num_feats), fill_value)
253253

254-
src = x
255254
if reduce == 'pow_sum':
256255
reduce = 'sum'
257-
src = x.detach() * x.detach() if self.semi_grad else x * x
258-
out.scatter_reduce_(0, index, src, reduce, include_self=True)
256+
if self.semi_grad:
257+
with torch.no_grad():
258+
out.scatter_reduce_(0, index, x * x, reduce,
259+
include_self=True)
260+
else:
261+
out.scatter_reduce_(0, index, x * x, reduce,
262+
include_self=True)
263+
else:
264+
out.scatter_reduce_(0, index, x, reduce, include_self=True)
259265

260266
if fill_value != 0.0:
261267
assert mask is not None

0 commit comments

Comments (0)