
Commit ba039e4

Authored by EdisonLeeeee, rusty1s, and pre-commit-ci[bot]
Add dropout_node to torch_geometric.utils (#5481)
* add dropout_node
* pass num_nodes to subgraph
* drop relabel_nodes argument
* Update torch_geometric/utils/dropout.py
  Co-authored-by: Matthias Fey <[email protected]>
* Update torch_geometric/utils/dropout.py
  Co-authored-by: Matthias Fey <[email protected]>
* Update torch_geometric/utils/dropout.py
  Co-authored-by: Matthias Fey <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* update
* test
* changelog

Co-authored-by: Matthias Fey <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent d3f35cb commit ba039e4

File tree: 4 files changed (+78, -2 lines)


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [2.2.0] - 2022-MM-DD
 ### Added
+- Added `dropout_node` augmentation that randomly drops nodes from a graph ([#5481](https://github.com/pyg-team/pytorch_geometric/pull/5481))
 - Added `AddRandomMetaPaths` that adds edges based on random walks along a metapath ([#5397](https://github.com/pyg-team/pytorch_geometric/pull/5397))
 - Added `WLConvContinuous` for performing WL refinement with continuous attributes ([#5316](https://github.com/pyg-team/pytorch_geometric/pull/5316))
 - Added `print_summary` method for the `torch_geometric.data.Dataset` interface ([#5438](https://github.com/pyg-team/pytorch_geometric/pull/5438))

test/utils/test_dropout.py

Lines changed: 19 additions & 1 deletion
@@ -1,6 +1,6 @@
 import torch
 
-from torch_geometric.utils import dropout_adj
+from torch_geometric.utils import dropout_adj, dropout_node
 
 
 def test_dropout_adj():
@@ -23,3 +23,21 @@ def test_dropout_adj():
     out = dropout_adj(edge_index, edge_attr, force_undirected=True)
     assert out[0].tolist() == [[0, 1, 1, 2], [1, 2, 0, 1]]
     assert out[1].tolist() == [1, 3, 1, 3]
+
+
+def test_dropout_node():
+    edge_index = torch.tensor([
+        [0, 1, 1, 2, 2, 3],
+        [1, 0, 2, 1, 3, 2],
+    ])
+
+    out = dropout_node(edge_index, training=False)
+    assert edge_index.tolist() == out[0].tolist()
+    assert out[1].tolist() == [False, False, False, False, False, False]
+    assert out[2].tolist() == [False, False, False, False]
+
+    torch.manual_seed(5)
+    out = dropout_node(edge_index)
+    assert out[0].tolist() == [[2, 3], [3, 2]]
+    assert out[1].tolist() == [False, False, False, False, True, True]
+    assert out[2].tolist() == [True, False, True, True]
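
The seeded case above pins the sampled node mask. As a sanity check (illustrative only, not part of the commit; it assumes the default CPU generator and that `dropout_node` draws exactly one `torch.rand` value per node before anything else consumes the RNG), the expected mask can be reproduced directly:

    import torch

    torch.manual_seed(5)
    prob = torch.rand(4)       # one probability per node, as dropout_node samples them
    node_mask = prob > 0.5     # default p=0.5; True means the node is kept
    print(node_mask.tolist())  # should match out[2] above: [True, False, True, True]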

torch_geometric/utils/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,6 @@
 from .degree import degree
 from .softmax import softmax
-from .dropout import dropout_adj
+from .dropout import dropout_adj, dropout_node
 from .sort_edge_index import sort_edge_index
 from .coalesce import coalesce
 from .undirected import is_undirected, to_undirected
@@ -39,6 +39,7 @@
     'degree',
     'softmax',
     'dropout_adj',
+    'dropout_node',
     'sort_edge_index',
     'coalesce',
     'is_undirected',

torch_geometric/utils/dropout.py

Lines changed: 56 additions & 0 deletions
@@ -5,6 +5,9 @@
 
 from torch_geometric.typing import OptTensor
 
+from .num_nodes import maybe_num_nodes
+from .subgraph import subgraph
+
 
 def filter_adj(row: Tensor, col: Tensor, edge_attr: OptTensor,
                mask: Tensor) -> Tuple[Tensor, Tensor, OptTensor]:
@@ -79,3 +82,56 @@ def dropout_adj(
     edge_index = torch.stack([row, col], dim=0)
 
     return edge_index, edge_attr
+
+
+def dropout_node(edge_index: Tensor, p: float = 0.5,
+                 num_nodes: Optional[int] = None,
+                 training: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
+    r"""Randomly drops nodes from the adjacency matrix
+    :obj:`edge_index` with probability :obj:`p` using samples from
+    a Bernoulli distribution.
+
+    The method returns (1) the retained :obj:`edge_index`, (2) the edge mask
+    indicating which edges were retained, and (3) the node mask indicating
+    which nodes were retained.
+
+    Args:
+        edge_index (LongTensor): The edge indices.
+        p (float, optional): Dropout probability. (default: :obj:`0.5`)
+        num_nodes (int, optional): The number of nodes, *i.e.*
+            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
+        training (bool, optional): If set to :obj:`False`, this operation is a
+            no-op. (default: :obj:`True`)
+
+    :rtype: (:class:`LongTensor`, :class:`BoolTensor`, :class:`BoolTensor`)
+
+    Examples:
+
+        >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
+        ...                            [1, 0, 2, 1, 3, 2]])
+        >>> edge_index, edge_mask, node_mask = dropout_node(edge_index)
+        >>> edge_index
+        tensor([[0, 1],
+                [1, 0]])
+        >>> edge_mask
+        tensor([ True,  True, False, False, False, False])
+        >>> node_mask
+        tensor([ True,  True, False, False])
+    """
+    if p < 0. or p > 1.:
+        raise ValueError(f'Dropout probability has to be between 0 and 1 '
+                         f'(got {p})')
+
+    num_nodes = maybe_num_nodes(edge_index, num_nodes)
+
+    if not training or p == 0.0:
+        node_mask = edge_index.new_zeros(num_nodes, dtype=torch.bool)
+        edge_mask = edge_index.new_zeros(edge_index.size(1), dtype=torch.bool)
+        return edge_index, edge_mask, node_mask
+
+    prob = torch.rand(num_nodes, device=edge_index.device)
+    node_mask = prob > p
+    edge_index, _, edge_mask = subgraph(node_mask, edge_index,
+                                        num_nodes=num_nodes,
+                                        return_edge_mask=True)
+    return edge_index, edge_mask, node_mask
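
For orientation, here is a minimal usage sketch of the new utility inside a model's forward pass. The two-layer GCN, the channel sizes, and `p=0.2` are illustrative assumptions and not part of this commit; only the `dropout_node` call itself reflects the API added here:

    import torch
    import torch.nn.functional as F

    from torch_geometric.nn import GCNConv
    from torch_geometric.utils import dropout_node


    class GCN(torch.nn.Module):
        # Hypothetical two-layer GCN used only to illustrate dropout_node.
        def __init__(self, in_channels: int, hidden_channels: int,
                     out_channels: int):
            super().__init__()
            self.conv1 = GCNConv(in_channels, hidden_channels)
            self.conv2 = GCNConv(hidden_channels, out_channels)

        def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
            # Randomly drop ~20% of the nodes (and their incident edges) while
            # training; with training=False the call is a no-op.
            edge_index, edge_mask, node_mask = dropout_node(
                edge_index, p=0.2, num_nodes=x.size(0), training=self.training)
            x = F.relu(self.conv1(x, edge_index))
            return self.conv2(x, edge_index)

Since nodes are not relabeled, dropped nodes simply become isolated: the feature matrix `x` keeps its original size and downstream shapes are unchanged.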
