@@ -85,7 +85,51 @@ def wrapper(self: 'Index', *args: Any, **kwargs: Any) -> Any:


class Index(Tensor):
-    r"""TODO."""
+ r"""A one-dimensional :obj:`index` tensor with additional (meta)data
89
+ attached.
90
+
91
+ :class:`Index` is a :pytorch:`null` :class:`torch.Tensor` that holds
92
+ indices of shape :obj:`[num_indices]`.
93
+
94
+ While :class:`Index` sub-classes a general :pytorch:`null`
95
+ :class:`torch.Tensor`, it can hold additional (meta)data, *i.e.*:
96
+
97
+ * :obj:`dim_size`: The size of the underlying sparse vector size, *i.e.*,
98
+ the size of a dimension that can be indexed via :obj:`index`.
99
+ By default, it is inferred as :obj:`dim_size=index.max() + 1`.
100
+ * :obj:`is_sorted`: Whether indices are sorted in ascending order.
101
+
102
+ Additionally, :class:`Index` caches data via :obj:`indptr` for fast CSR
103
+ conversion in case its representation is sorted.
104
+ Caches are filled based on demand (*e.g.*, when calling
105
+ :meth:`Index.get_indptr`), or when explicitly requested via
106
+ :meth:`Index.fill_cache_`, and are maintaned and adjusted over its
107
+ lifespan.
108
+
109
+ This representation ensures for optimal computation in GNN message passing
110
+ schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
111
+ workflows.
+
+    .. code-block:: python
+
+        import torch
+        from torch_geometric import Index
+
+        index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
+        >>> Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
+        assert index.dim_size == 3
+        assert index.is_sorted
+
+        # Flipping order:
+        index.flip(0)
+        >>> Index([2, 1, 1, 0], dim_size=3)
+        assert not index.flip(0).is_sorted
+
+        # Filtering:
+        mask = torch.tensor([True, True, True, False])
+        index[mask]
+        >>> Index([0, 1, 1], dim_size=3, is_sorted=True)
+        assert index[mask].is_sorted
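+
+        # Cache behavior (a sketch of the CSR caching described above):
+        # for a sorted index, the compressed pointer is computed once on
+        # demand via `get_indptr()` and cached for subsequent calls.
+        index.get_indptr()
+        >>> tensor([0, 1, 3, 4])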
+    """
    # See "https://pytorch.org/docs/stable/notes/extending.html"
    # for a basic tutorial on how to subclass `torch.Tensor`.

@@ -166,7 +210,13 @@ def __new__(
    # Validation ##############################################################

    def validate(self) -> 'Index':
- r"""TODO."""
213
+ r"""Validates the :class:`Index` representation.
214
+
215
+ In particular, it ensures that
216
+
217
+ * it only holds valid indices.
218
+ * the sort order is correctly set.
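+
+        A minimal usage sketch (assuming a well-formed index; the method
+        returns the instance itself, so calls can be chained):
+
+        .. code-block:: python
+
+            index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
+            index.validate()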
+        """
        assert_valid_dtype(self._data)
        assert_one_dimensional(self._data)
        assert_contiguous(self._data)
@@ -191,12 +241,12 @@ def validate(self) -> 'Index':

    @property
    def dim_size(self) -> Optional[int]:
-        r"""TODO."""
+        r"""The size of the underlying sparse vector."""
        return self._dim_size

    @property
    def is_sorted(self) -> bool:
-        r"""TODO."""
+        r"""Returns whether indices are sorted in ascending order."""
        return self._is_sorted

    @property
@@ -207,7 +257,9 @@ def dtype(self) -> torch.dtype:  # type: ignore
    # Cache Interface #########################################################

    def get_dim_size(self) -> int:
-        r"""TODO."""
+        r"""The size of the underlying sparse vector.
+        Automatically computed and cached when not explicitly set.
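+
+        For example (a small sketch of the inference rule):
+
+        .. code-block:: python
+
+            index = Index([0, 1, 1, 2])  # No explicit dim_size given.
+            assert index.get_dim_size() == 3  # Inferred as index.max() + 1.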
+        """
        if self._dim_size is None:
            dim_size = int(self._data.max()) + 1 if self.numel() > 0 else 0
            self._dim_size = dim_size
@@ -216,7 +268,7 @@ def get_dim_size(self) -> int:
        return self._dim_size

    def dim_resize_(self, dim_size: Optional[int]) -> 'Index':
-        r"""TODO."""
+        r"""Assigns or re-assigns the size of the underlying sparse vector.
        if self.is_sorted and self._indptr is not None:
            if dim_size is None:
                self._indptr = None
@@ -237,15 +289,17 @@ def dim_resize_(self, dim_size: Optional[int]) -> 'Index':

    @assert_sorted
    def get_indptr(self) -> Tensor:
-        r"""TODO."""
+        r"""Returns the compressed index representation in case :class:`Index`
+        is sorted.
+        """
        if self._indptr is None:
            self._indptr = index2ptr(self._data, self.get_dim_size())

        assert isinstance(self._indptr, Tensor)
        return self._indptr

    def fill_cache_(self) -> 'Index':
-        r"""TODO."""
+        r"""Fills the cache with (meta)data information.
        self.get_dim_size()

        if self.is_sorted:
@@ -256,12 +310,14 @@ def fill_cache_(self) -> 'Index':
    # Methods #################################################################

    def share_memory_(self) -> 'Index':
+        """"""  # noqa: D419
        self._data.share_memory_()
        if self._indptr is not None:
            self._indptr.share_memory_()
        return self

    def is_shared(self) -> bool:
+        """"""  # noqa: D419
        return self._data.is_shared()

    def as_tensor(self) -> Tensor: