15
15
Union ,
16
16
)
17
17
18
+ import numpy as np
18
19
import torch
19
20
from torch import Tensor
20
21
from torch_sparse import SparseTensor , coalesce
@@ -265,8 +266,10 @@ def num_nodes(self) -> Optional[int]:
265
266
if 'num_nodes' in self :
266
267
return self ['num_nodes' ]
267
268
for key , value in self .items ():
268
- if isinstance (value , Tensor ) and (key in N_KEYS or 'node' in key ):
269
- return value .size (self ._parent ().__cat_dim__ (key , value , self ))
269
+ if (isinstance (value , (Tensor , np .ndarray ))
270
+ and (key in N_KEYS or 'node' in key )):
271
+ cat_dim = self ._parent ().__cat_dim__ (key , value , self )
272
+ return value .shape [cat_dim ]
270
273
if 'adj' in self and isinstance (self .adj , SparseTensor ):
271
274
return self .adj .size (0 )
272
275
if 'adj_t' in self and isinstance (self .adj_t , SparseTensor ):
@@ -291,7 +294,9 @@ def num_nodes(self) -> Optional[int]:
291
294
292
295
@property
293
296
def num_node_features (self ) -> int :
294
- if 'x' in self and isinstance (self .x , (Tensor , SparseTensor )):
297
+ if 'x' in self and isinstance (self .x , (Tensor , np .ndarray )):
298
+ return 1 if self .x .ndim == 1 else self .x .shape [- 1 ]
299
+ if 'x' in self and isinstance (self .x , SparseTensor ):
295
300
return 1 if self .x .dim () == 1 else self .x .size (- 1 )
296
301
return 0
297
302
@@ -302,9 +307,9 @@ def num_features(self) -> int:
302
307
def is_node_attr (self , key : str ) -> bool :
303
308
value = self [key ]
304
309
cat_dim = self ._parent ().__cat_dim__ (key , value , self )
305
- if not isinstance (value , Tensor ):
310
+ if not isinstance (value , ( Tensor , np . ndarray ) ):
306
311
return False
307
- if value .dim () == 0 or value .size ( cat_dim ) != self .num_nodes :
312
+ if value .ndim == 0 or value .shape [ cat_dim ] != self .num_nodes :
308
313
return False
309
314
return True
310
315
@@ -350,17 +355,19 @@ def edge_index(self) -> Tensor:
350
355
def num_edges (self ) -> int :
351
356
# We sequentially access attributes that reveal the number of edges.
352
357
for key , value in self .items ():
353
- if isinstance (value , Tensor ) and 'edge' in key :
354
- return value .size (self ._parent ().__cat_dim__ (key , value , self ))
358
+ if isinstance (value , (Tensor , np .ndarray )) and 'edge' in key :
359
+ cat_dim = self ._parent ().__cat_dim__ (key , value , self )
360
+ return value .shape [cat_dim ]
355
361
for value in self .values ('adj' , 'adj_t' ):
356
362
if isinstance (value , SparseTensor ):
357
363
return value .nnz ()
358
364
return 0
359
365
360
366
@property
361
367
def num_edge_features (self ) -> int :
362
- if 'edge_attr' in self and isinstance (self .edge_attr , Tensor ):
363
- return 1 if self .edge_attr .dim () == 1 else self .edge_attr .size (- 1 )
368
+ if ('edge_attr' in self and isinstance (self .edge_attr ,
369
+ (Tensor , np .ndarray ))):
370
+ return 1 if self .edge_attr .ndim == 1 else self .edge_attr .shape [- 1 ]
364
371
return 0
365
372
366
373
@property
@@ -386,9 +393,9 @@ def is_node_attr(self, key: str) -> bool:
386
393
def is_edge_attr (self , key : str ) -> bool :
387
394
value = self [key ]
388
395
cat_dim = self ._parent ().__cat_dim__ (key , value , self )
389
- if not isinstance (value , Tensor ):
396
+ if not isinstance (value , ( Tensor , np . ndarray ) ):
390
397
return False
391
- if value .dim () == 0 or value .size ( cat_dim ) != self .num_edges :
398
+ if value .ndim == 0 or value .shape [ cat_dim ] != self .num_edges :
392
399
return False
393
400
return True
394
401
@@ -467,9 +474,9 @@ def is_node_attr(self, key: str) -> bool:
467
474
cat_dim = self ._parent ().__cat_dim__ (key , value , self )
468
475
469
476
num_nodes , num_edges = self .num_nodes , self .num_edges
470
- if not isinstance (value , Tensor ):
477
+ if not isinstance (value , ( Tensor , np . ndarray ) ):
471
478
return False
472
- if value .dim () == 0 or value .size ( cat_dim ) != num_nodes :
479
+ if value .ndim == 0 or value .shape [ cat_dim ] != num_nodes :
473
480
return False
474
481
if num_nodes != num_edges :
475
482
return True
@@ -480,9 +487,9 @@ def is_edge_attr(self, key: str) -> bool:
480
487
cat_dim = self ._parent ().__cat_dim__ (key , value , self )
481
488
482
489
num_nodes , num_edges = self .num_nodes , self .num_edges
483
- if not isinstance (value , Tensor ):
490
+ if not isinstance (value , ( Tensor , np . ndarray ) ):
484
491
return False
485
- if value .dim () == 0 or value .size ( cat_dim ) != num_edges :
492
+ if value .ndim == 0 or value .shape [ cat_dim ] != num_edges :
486
493
return False
487
494
if num_nodes != num_edges :
488
495
return True
0 commit comments