def forward(self, x: Tensor, index: Optional[Tensor] = None,
            ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
            dim: int = -2) -> Tensor:
    """Compute a per-group variance of ``x`` as ``E[x*x] - E[x]**2``.

    Args:
        x: The input features to aggregate.
        index: Per-element group assignment along ``dim``
            (scatter-style grouping) — presumably exclusive with
            ``ptr``; semantics defined by ``self.reduce``.
        ptr: CSR-style group boundary pointers, the alternative
            grouping representation.
        dim_size: The number of output groups.
        dim: The dimension along which to aggregate (default: ``-2``).

    Returns:
        A tensor of per-group (biased) variances.
    """
    from contextlib import nullcontext

    mean = self.reduce(x, index, ptr, dim_size, dim, reduce='mean')

    # When `semi_grad` is set, the second moment is computed without
    # autograd tracking to save memory; gradients then flow only
    # through the first-moment term `mean`.  Selecting the context up
    # front avoids duplicating the identical `reduce` call in both
    # branches.
    grad_ctx = torch.no_grad() if self.semi_grad else nullcontext()
    with grad_ctx:
        mean2 = self.reduce(x * x, index, ptr, dim_size, dim, 'mean')

    # NOTE(review): floating-point error can make this marginally
    # negative for near-constant groups — confirm whether callers
    # (e.g. a std aggregation) clamp before taking a square root.
    return mean2 - mean * mean
class StdAggregation (Aggregation ):
@@ -200,9 +203,10 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None,
200
203
alpha = x * t
201
204
202
205
if not self .learn and self .semi_grad :
203
- alpha = alpha .detach ()
204
- alpha = softmax (alpha , index , ptr , dim_size , dim )
205
- alpha = softmax (alpha , index , ptr , dim_size , dim )
206
+ with torch .no_grad ():
207
+ alpha = softmax (alpha , index , ptr , dim_size , dim )
208
+ else :
209
+ alpha = softmax (alpha , index , ptr , dim_size , dim )
206
210
return self .reduce (x * alpha , index , ptr , dim_size , dim , reduce = 'sum' )
207
211
208
212
def __repr__ (self ) -> str :
0 commit comments