Skip to content

Commit 7d0fcd2

Browse files
authored
Add documentation to torch_geometric.Index (#9297)
1 parent 8bb44ed commit 7d0fcd2

File tree

4 files changed

+69
-10
lines changed

4 files changed

+69
-10
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
88
### Added
99

1010
- Support `EdgeIndex.sparse_narrow` for non-sorted edge indices ([#9291](https://github.com/pyg-team/pytorch_geometric/pull/9291))
11-
- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277), [#9278](https://github.com/pyg-team/pytorch_geometric/pull/9278), [#9279](https://github.com/pyg-team/pytorch_geometric/pull/9279), [#9280](https://github.com/pyg-team/pytorch_geometric/pull/9280), [#9281](https://github.com/pyg-team/pytorch_geometric/pull/9281), [#9284](https://github.com/pyg-team/pytorch_geometric/pull/9284), [#9285](https://github.com/pyg-team/pytorch_geometric/pull/9285), [#9286](https://github.com/pyg-team/pytorch_geometric/pull/9286), [#9287](https://github.com/pyg-team/pytorch_geometric/pull/9287), [#9288](https://github.com/pyg-team/pytorch_geometric/pull/9288), [#9289](https://github.com/pyg-team/pytorch_geometric/pull/9289))
11+
- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277), [#9278](https://github.com/pyg-team/pytorch_geometric/pull/9278), [#9279](https://github.com/pyg-team/pytorch_geometric/pull/9279), [#9280](https://github.com/pyg-team/pytorch_geometric/pull/9280), [#9281](https://github.com/pyg-team/pytorch_geometric/pull/9281), [#9284](https://github.com/pyg-team/pytorch_geometric/pull/9284), [#9285](https://github.com/pyg-team/pytorch_geometric/pull/9285), [#9286](https://github.com/pyg-team/pytorch_geometric/pull/9286), [#9287](https://github.com/pyg-team/pytorch_geometric/pull/9287), [#9288](https://github.com/pyg-team/pytorch_geometric/pull/9288), [#9289](https://github.com/pyg-team/pytorch_geometric/pull/9289), [#9297](https://github.com/pyg-team/pytorch_geometric/pull/9297))
1212
- Added support for PyTorch 2.3 ([#9240](https://github.com/pyg-team/pytorch_geometric/pull/9240))
1313
- Added support for `EdgeIndex` in `message_and_aggregate` ([#9131](https://github.com/pyg-team/pytorch_geometric/pull/9131))
1414
- Added `CornellTemporalHyperGraphDataset` ([#9090](https://github.com/pyg-team/pytorch_geometric/pull/9090))

docs/source/modules/root.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@ torch_geometric
44
Tensor Objects
55
--------------
66

7-
.. currentmodule:: torch_geometric.edge_index
7+
.. currentmodule:: torch_geometric
88

99
.. autosummary::
1010
:nosignatures:
1111
:toctree: ../generated
1212

13+
Index
1314
EdgeIndex
1415

1516
Functions

torch_geometric/edge_index.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,7 @@ def fill_cache_(self, no_transpose: bool = False) -> 'EdgeIndex':
698698
# Methods #################################################################
699699

700700
def share_memory_(self) -> 'EdgeIndex':
701+
"""""" # noqa: D419
701702
self._data.share_memory_()
702703
if self._indptr is not None:
703704
self._indptr.share_memory_()
@@ -714,6 +715,7 @@ def share_memory_(self) -> 'EdgeIndex':
714715
return self
715716

716717
def is_shared(self) -> bool:
718+
"""""" # noqa: D419
717719
return self._data.is_shared()
718720

719721
def as_tensor(self) -> Tensor:

torch_geometric/index.py

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,51 @@ def wrapper(self: 'Index', *args: Any, **kwargs: Any) -> Any:
8585

8686

8787
class Index(Tensor):
88-
r"""TODO."""
88+
r"""A one-dimensional :obj:`index` tensor with additional (meta)data
89+
attached.
90+
91+
:class:`Index` is a :pytorch:`null` :class:`torch.Tensor` that holds
92+
indices of shape :obj:`[num_indices]`.
93+
94+
While :class:`Index` sub-classes a general :pytorch:`null`
95+
:class:`torch.Tensor`, it can hold additional (meta)data, *i.e.*:
96+
97+
* :obj:`dim_size`: The size of the underlying sparse vector, *i.e.*,
98+
the size of a dimension that can be indexed via :obj:`index`.
99+
By default, it is inferred as :obj:`dim_size=index.max() + 1`.
100+
* :obj:`is_sorted`: Whether indices are sorted in ascending order.
101+
102+
Additionally, :class:`Index` caches data via :obj:`indptr` for fast CSR
103+
conversion in case its representation is sorted.
104+
Caches are filled based on demand (*e.g.*, when calling
105+
:meth:`Index.get_indptr`), or when explicitly requested via
106+
:meth:`Index.fill_cache_`, and are maintained and adjusted over its
107+
lifespan.
108+
109+
This representation ensures optimal computation in GNN message passing
110+
schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
111+
workflows.
112+
113+
.. code-block:: python
114+
115+
from torch_geometric import Index
116+
117+
index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
118+
>>> Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
119+
assert index.dim_size == 3
120+
assert index.is_sorted
121+
122+
# Flipping order:
123+
index.flip(0)
124+
>>> Index([2, 1, 1, 0], dim_size=3)
125+
assert not index.is_sorted
126+
127+
# Filtering:
128+
mask = torch.tensor([True, True, True, False])
129+
index[mask]
130+
>>> Index([0, 1, 1], dim_size=3, is_sorted=True)
131+
assert index.is_sorted
132+
"""
89133
# See "https://pytorch.org/docs/stable/notes/extending.html"
90134
# for a basic tutorial on how to subclass `torch.Tensor`.
91135

@@ -166,7 +210,13 @@ def __new__(
166210
# Validation ##############################################################
167211

168212
def validate(self) -> 'Index':
169-
r"""TODO."""
213+
r"""Validates the :class:`Index` representation.
214+
215+
In particular, it ensures that
216+
217+
* it only holds valid indices.
218+
* the sort order is correctly set.
219+
"""
170220
assert_valid_dtype(self._data)
171221
assert_one_dimensional(self._data)
172222
assert_contiguous(self._data)
@@ -191,12 +241,12 @@ def validate(self) -> 'Index':
191241

192242
@property
193243
def dim_size(self) -> Optional[int]:
194-
r"""TODO."""
244+
r"""The size of the underlying sparse vector."""
195245
return self._dim_size
196246

197247
@property
198248
def is_sorted(self) -> bool:
199-
r"""TODO."""
249+
r"""Returns whether indices are sorted in ascending order."""
200250
return self._is_sorted
201251

202252
@property
@@ -207,7 +257,9 @@ def dtype(self) -> torch.dtype: # type: ignore
207257
# Cache Interface #########################################################
208258

209259
def get_dim_size(self) -> int:
210-
r"""TODO."""
260+
r"""The size of the underlying sparse vector.
261+
Automatically computed and cached when not explicitly set.
262+
"""
211263
if self._dim_size is None:
212264
dim_size = int(self._data.max()) + 1 if self.numel() > 0 else 0
213265
self._dim_size = dim_size
@@ -216,7 +268,7 @@ def get_dim_size(self) -> int:
216268
return self._dim_size
217269

218270
def dim_resize_(self, dim_size: Optional[int]) -> 'Index':
219-
r"""TODO."""
271+
r"""Assigns or re-assigns the size of the underlying sparse vector."""
220272
if self.is_sorted and self._indptr is not None:
221273
if dim_size is None:
222274
self._indptr = None
@@ -237,15 +289,17 @@ def dim_resize_(self, dim_size: Optional[int]) -> 'Index':
237289

238290
@assert_sorted
239291
def get_indptr(self) -> Tensor:
240-
r"""TODO."""
292+
r"""Returns the compressed index representation in case :class:`Index`
293+
is sorted.
294+
"""
241295
if self._indptr is None:
242296
self._indptr = index2ptr(self._data, self.get_dim_size())
243297

244298
assert isinstance(self._indptr, Tensor)
245299
return self._indptr
246300

247301
def fill_cache_(self) -> 'Index':
248-
r"""TODO."""
302+
r"""Fills the cache with (meta)data information."""
249303
self.get_dim_size()
250304

251305
if self.is_sorted:
@@ -256,12 +310,14 @@ def fill_cache_(self) -> 'Index':
256310
# Methods #################################################################
257311

258312
def share_memory_(self) -> 'Index':
313+
"""""" # noqa: D419
259314
self._data.share_memory_()
260315
if self._indptr is not None:
261316
self._indptr.share_memory_()
262317
return self
263318

264319
def is_shared(self) -> bool:
320+
"""""" # noqa: D419
265321
return self._data.is_shared()
266322

267323
def as_tensor(self) -> Tensor:

0 commit comments

Comments
 (0)