Add documentation to torch_geometric.Index #9297

Merged 2 commits on May 6, 2024
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Support `EdgeIndex.sparse_narrow` for non-sorted edge indices ([#9291](https://github.com/pyg-team/pytorch_geometric/pull/9291))
- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277), [#9278](https://github.com/pyg-team/pytorch_geometric/pull/9278), [#9279](https://github.com/pyg-team/pytorch_geometric/pull/9279), [#9280](https://github.com/pyg-team/pytorch_geometric/pull/9280), [#9281](https://github.com/pyg-team/pytorch_geometric/pull/9281), [#9284](https://github.com/pyg-team/pytorch_geometric/pull/9284), [#9285](https://github.com/pyg-team/pytorch_geometric/pull/9285), [#9286](https://github.com/pyg-team/pytorch_geometric/pull/9286), [#9287](https://github.com/pyg-team/pytorch_geometric/pull/9287), [#9288](https://github.com/pyg-team/pytorch_geometric/pull/9288), [#9289](https://github.com/pyg-team/pytorch_geometric/pull/9289))
- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277), [#9278](https://github.com/pyg-team/pytorch_geometric/pull/9278), [#9279](https://github.com/pyg-team/pytorch_geometric/pull/9279), [#9280](https://github.com/pyg-team/pytorch_geometric/pull/9280), [#9281](https://github.com/pyg-team/pytorch_geometric/pull/9281), [#9284](https://github.com/pyg-team/pytorch_geometric/pull/9284), [#9285](https://github.com/pyg-team/pytorch_geometric/pull/9285), [#9286](https://github.com/pyg-team/pytorch_geometric/pull/9286), [#9287](https://github.com/pyg-team/pytorch_geometric/pull/9287), [#9288](https://github.com/pyg-team/pytorch_geometric/pull/9288), [#9289](https://github.com/pyg-team/pytorch_geometric/pull/9289), [#9297](https://github.com/pyg-team/pytorch_geometric/pull/9297))
- Added support for PyTorch 2.3 ([#9240](https://github.com/pyg-team/pytorch_geometric/pull/9240))
- Added support for `EdgeIndex` in `message_and_aggregate` ([#9131](https://github.com/pyg-team/pytorch_geometric/pull/9131))
- Added `CornellTemporalHyperGraphDataset` ([#9090](https://github.com/pyg-team/pytorch_geometric/pull/9090))
3 changes: 2 additions & 1 deletion docs/source/modules/root.rst
@@ -4,12 +4,13 @@ torch_geometric
Tensor Objects
--------------

.. currentmodule:: torch_geometric.edge_index
.. currentmodule:: torch_geometric

.. autosummary::
:nosignatures:
:toctree: ../generated

Index
EdgeIndex

Functions
2 changes: 2 additions & 0 deletions torch_geometric/edge_index.py
@@ -698,6 +698,7 @@ def fill_cache_(self, no_transpose: bool = False) -> 'EdgeIndex':
# Methods #################################################################

def share_memory_(self) -> 'EdgeIndex':
"""""" # noqa: D419
self._data.share_memory_()
if self._indptr is not None:
self._indptr.share_memory_()
@@ -714,6 +715,7 @@ def share_memory_(self) -> 'EdgeIndex':
return self

def is_shared(self) -> bool:
"""""" # noqa: D419
return self._data.is_shared()

def as_tensor(self) -> Tensor:
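Aside: the empty docstrings added here (`""""""  # noqa: D419`) are a deliberate pattern. An explicitly empty docstring gives the override a `__doc__` of `''`, which stops `inspect.getdoc` (and therefore Sphinx autodoc) from falling back to the inherited `torch.Tensor` description, while `# noqa: D419` silences pydocstyle's "Docstring is empty" check. A minimal sketch of this reading (the `MyTensor` class is hypothetical, for illustration only):

```python
import inspect

from torch import Tensor


class MyTensor(Tensor):
    def is_shared(self) -> bool:
        """"""  # noqa: D419
        return super().is_shared()


# The empty docstring blocks docstring inheritance from `torch.Tensor`:
assert inspect.getdoc(MyTensor.is_shared) == ''
assert inspect.getdoc(Tensor.is_shared)  # The parent docstring is non-empty.
```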
72 changes: 64 additions & 8 deletions torch_geometric/index.py
@@ -85,7 +85,51 @@ def wrapper(self: 'Index', *args: Any, **kwargs: Any) -> Any:


class Index(Tensor):
r"""TODO."""
r"""A one-dimensional :obj:`index` tensor with additional (meta)data
attached.

:class:`Index` is a :pytorch:`null` :class:`torch.Tensor` that holds
indices of shape :obj:`[num_indices]`.

While :class:`Index` sub-classes a general :pytorch:`null`
:class:`torch.Tensor`, it can hold additional (meta)data, *i.e.*:

* :obj:`dim_size`: The size of the underlying sparse vector, *i.e.*,
the size of the dimension that can be indexed via :obj:`index`.
By default, it is inferred as :obj:`dim_size=index.max() + 1`.
* :obj:`is_sorted`: Whether indices are sorted in ascending order.

Additionally, :class:`Index` caches data via :obj:`indptr` for fast CSR
conversion in case its representation is sorted.
Caches are filled based on demand (*e.g.*, when calling
:meth:`Index.get_indptr`), or when explicitly requested via
:meth:`Index.fill_cache_`, and are maintained and adjusted over its
lifespan.

This representation ensures optimal computation in GNN message passing
schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
workflows.

.. code-block:: python

import torch

from torch_geometric import Index

index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
>>> Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
assert index.dim_size == 3
assert index.is_sorted

# Flipping order:
index.flip(0)
>>> Index([2, 1, 1, 0], dim_size=3)
assert not index.flip(0).is_sorted

# Filtering:
mask = torch.tensor([True, True, True, False])
index[mask]
>>> Index([0, 1, 1], dim_size=3, is_sorted=True)
assert index[mask].is_sorted
"""
"""
# See "https://pytorch.org/docs/stable/notes/extending.html"
# for a basic tutorial on how to subclass `torch.Tensor`.

@@ -166,7 +210,13 @@ def __new__(
# Validation ##############################################################

def validate(self) -> 'Index':
r"""TODO."""
r"""Validates the :class:`Index` representation.

In particular, it ensures that

* it only holds valid indices, and
* its sort order is correctly set.
"""
assert_valid_dtype(self._data)
assert_one_dimensional(self._data)
assert_contiguous(self._data)
@@ -191,12 +241,12 @@ def validate(self) -> 'Index':

@property
def dim_size(self) -> Optional[int]:
r"""TODO."""
r"""The size of the underlying sparse vector."""
return self._dim_size

@property
def is_sorted(self) -> bool:
r"""TODO."""
r"""Returns whether indices are sorted in ascending order."""
return self._is_sorted

@property
@@ -207,7 +257,9 @@ def dtype(self) -> torch.dtype: # type: ignore
# Cache Interface #########################################################

def get_dim_size(self) -> int:
r"""TODO."""
r"""The size of the underlying sparse vector.
Automatically computed and cached when not explicitly set.
"""
if self._dim_size is None:
dim_size = int(self._data.max()) + 1 if self.numel() > 0 else 0
self._dim_size = dim_size
@@ -216,7 +268,7 @@ def get_dim_size(self) -> int:
return self._dim_size

def dim_resize_(self, dim_size: Optional[int]) -> 'Index':
r"""TODO."""
r"""Assigns or re-assigns the size of the underlying sparse vector."""
if self.is_sorted and self._indptr is not None:
if dim_size is None:
self._indptr = None
@@ -237,15 +289,17 @@ def dim_resize_(self, dim_size: Optional[int]) -> 'Index':

@assert_sorted
def get_indptr(self) -> Tensor:
r"""TODO."""
r"""Returns the compressed index representation in case :class:`Index`
is sorted.
"""
if self._indptr is None:
self._indptr = index2ptr(self._data, self.get_dim_size())

assert isinstance(self._indptr, Tensor)
return self._indptr

def fill_cache_(self) -> 'Index':
r"""TODO."""
r"""Fills the cache with (meta)data information."""
self.get_dim_size()

if self.is_sorted:
@@ -256,12 +310,14 @@ def fill_cache_(self) -> 'Index':
# Methods #################################################################

def share_memory_(self) -> 'Index':
"""""" # noqa: D419
self._data.share_memory_()
if self._indptr is not None:
self._indptr.share_memory_()
return self

def is_shared(self) -> bool:
"""""" # noqa: D419
return self._data.is_shared()

def as_tensor(self) -> Tensor:
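To close, here is a short end-to-end sketch of the API documented in this PR. It is a sketch under the assumption that the docstrings above describe the implemented behavior; the `[0, 1, 3, 4]` pointer follows from the standard CSR definition (a leading zero followed by cumulative counts per index value):

```python
import torch

from torch_geometric import Index

index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)

# Metadata attached at construction time:
assert index.dim_size == 3
assert index.is_sorted

# Without an explicit `dim_size`, `get_dim_size` infers and caches
# `index.max() + 1`:
assert Index([0, 1, 1, 2]).get_dim_size() == 3

# For sorted indices, `get_indptr` computes (and caches) the CSR pointer:
assert index.get_indptr().tolist() == [0, 1, 3, 4]

# `fill_cache_` populates `dim_size` and `indptr` eagerly, and
# `as_tensor` recovers the underlying data as a plain tensor:
index.fill_cache_()
data = index.as_tensor()
```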