Skip to content

Commit 10b7373

Browse files
authored
Only add true negatives in add_random_edge augmentation (#7654)
Fixes #7653
1 parent 2ec1c4b commit 10b7373

File tree

5 files changed

+60
-53
lines changed

5 files changed

+60
-53
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
7070

7171
### Changed
7272

73+
- Changed `add_random_edge` to only add true negative edges ([#7654](https://github.com/pyg-team/pytorch_geometric/pull/7654))
7374
- Allowed the usage of `BasicGNN` models in `DeepGraphInfomax` ([#7648](https://github.com/pyg-team/pytorch_geometric/pull/7648))
7475
- Breaking Change: Made `Data.keys` a method rather than a property ([#7629](https://github.com/pyg-team/pytorch_geometric/pull/7629))
7576
- Added a `num_edges` parameter to the forward method of `HypergraphConv` ([#7560](https://github.com/pyg-team/pytorch_geometric/pull/7560))

test/utils/test_augmentation.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
import torch
33

4+
from torch_geometric import seed_everything
45
from torch_geometric.utils import (
56
add_random_edge,
67
is_undirected,
@@ -77,28 +78,26 @@ def test_add_random_edge():
7778
assert out[0].tolist() == edge_index.tolist()
7879
assert out[1].tolist() == [[], []]
7980

80-
torch.manual_seed(5)
81+
seed_everything(5)
8182
out = add_random_edge(edge_index, p=0.5)
82-
assert out[0].tolist() == [[0, 1, 1, 2, 2, 3, 3, 2, 3],
83-
[1, 0, 2, 1, 3, 2, 1, 2, 2]]
84-
85-
assert out[1].tolist() == [[3, 2, 3], [1, 2, 2]]
83+
assert out[0].tolist() == [[0, 1, 1, 2, 2, 3, 3, 1, 2],
84+
[1, 0, 2, 1, 3, 2, 0, 3, 0]]
85+
assert out[1].tolist() == [[3, 1, 2], [0, 3, 0]]
8686

87-
torch.manual_seed(6)
87+
seed_everything(6)
8888
out = add_random_edge(edge_index, p=0.5, force_undirected=True)
89-
assert out[0].tolist() == [[0, 1, 1, 2, 2, 3, 1, 2],
90-
[1, 0, 2, 1, 3, 2, 2, 1]]
91-
assert out[1].tolist() == [[1, 2], [2, 1]]
89+
assert out[0].tolist() == [[0, 1, 1, 2, 2, 3, 1, 3],
90+
[1, 0, 2, 1, 3, 2, 3, 1]]
91+
assert out[1].tolist() == [[1, 3], [3, 1]]
9292
assert is_undirected(out[0])
9393
assert is_undirected(out[1])
9494

95-
# test with bipartite graph
96-
torch.manual_seed(7)
95+
# Test for bipartite graph:
96+
seed_everything(7)
9797
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5], [2, 3, 1, 4, 2, 1]])
98-
with pytest.raises(RuntimeError,
99-
match="not supported for heterogeneous graphs"):
100-
out = add_random_edge(edge_index, p=0.5, force_undirected=True,
101-
num_nodes=(6, 5))
98+
with pytest.raises(RuntimeError, match="not supported for bipartite"):
99+
add_random_edge(edge_index, force_undirected=True, num_nodes=(6, 5))
102100
out = add_random_edge(edge_index, p=0.5, num_nodes=(6, 5))
103-
out[0].tolist() == [[0, 1, 2, 3, 4, 5, 3, 4, 1],
104-
[2, 3, 1, 4, 2, 1, 1, 3, 2]]
101+
assert out[0].tolist() == [[0, 1, 2, 3, 4, 5, 2, 0, 2],
102+
[2, 3, 1, 4, 2, 1, 0, 4, 2]]
103+
assert out[1].tolist() == [[2, 0, 2], [0, 4, 2]]

torch_geometric/utils/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from .degree import degree
77
from .softmax import softmax
88
from .dropout import dropout_adj, dropout_node, dropout_edge, dropout_path
9-
from .augmentation import shuffle_node, mask_feature, add_random_edge
109
from .sort_edge_index import sort_edge_index
1110
from .coalesce import coalesce
1211
from .undirected import is_undirected, to_undirected
@@ -47,6 +46,7 @@
4746
from .negative_sampling import (negative_sampling, batched_negative_sampling,
4847
structured_negative_sampling,
4948
structured_negative_sampling_feasible)
49+
from .augmentation import shuffle_node, mask_feature, add_random_edge
5050
from .tree_decomposition import tree_decomposition
5151
from .embedding import get_embeddings
5252
from .trim_to_layer import trim_to_layer
@@ -62,9 +62,6 @@
6262
'dropout_edge',
6363
'dropout_path',
6464
'dropout_adj',
65-
'shuffle_node',
66-
'mask_feature',
67-
'add_random_edge',
6865
'sort_edge_index',
6966
'coalesce',
7067
'is_undirected',
@@ -130,6 +127,9 @@
130127
'batched_negative_sampling',
131128
'structured_negative_sampling',
132129
'structured_negative_sampling_feasible',
130+
'shuffle_node',
131+
'mask_feature',
132+
'add_random_edge',
133133
'tree_decomposition',
134134
'get_embeddings',
135135
'trim_to_layer',

torch_geometric/utils/augmentation.py

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
import torch
44
from torch import Tensor
55

6-
from torch_geometric.utils import scatter
7-
from torch_geometric.utils.num_nodes import maybe_num_nodes
6+
from torch_geometric.utils import negative_sampling, scatter
87

98

10-
def shuffle_node(x: Tensor, batch: Optional[Tensor] = None,
11-
training: bool = True) -> Tuple[Tensor, Tensor]:
9+
def shuffle_node(
10+
x: Tensor,
11+
batch: Optional[Tensor] = None,
12+
training: bool = True,
13+
) -> Tuple[Tensor, Tensor]:
1214
r"""Randomly shuffle the feature matrix :obj:`x` along the
1315
first dimension.
1416
@@ -67,9 +69,13 @@ def shuffle_node(x: Tensor, batch: Optional[Tensor] = None,
6769
return x[perm], perm
6870

6971

70-
def mask_feature(x: Tensor, p: float = 0.5, mode: str = 'col',
71-
fill_value: float = 0.,
72-
training: bool = True) -> Tuple[Tensor, Tensor]:
72+
def mask_feature(
73+
x: Tensor,
74+
p: float = 0.5,
75+
mode: str = 'col',
76+
fill_value: float = 0.,
77+
training: bool = True,
78+
) -> Tuple[Tensor, Tensor]:
7379
r"""Randomly masks feature from the feature matrix
7480
:obj:`x` with probability :obj:`p` using samples from
7581
a Bernoulli distribution.
@@ -149,9 +155,13 @@ def mask_feature(x: Tensor, p: float = 0.5, mode: str = 'col',
149155
return x, mask
150156

151157

152-
def add_random_edge(edge_index, p: float, force_undirected: bool = False,
153-
num_nodes: Optional[Union[Tuple[int], int]] = None,
154-
training: bool = True) -> Tuple[Tensor, Tensor]:
158+
def add_random_edge(
159+
edge_index,
160+
p: float = 0.5,
161+
force_undirected: bool = False,
162+
num_nodes: Optional[Union[int, Tuple[int, int]]] = None,
163+
training: bool = True,
164+
) -> Tuple[Tensor, Tensor]:
155165
r"""Randomly adds edges to :obj:`edge_index`.
156166
157167
The method returns (1) the retained :obj:`edge_index`, (2) the added
@@ -160,6 +170,7 @@ def add_random_edge(edge_index, p: float, force_undirected: bool = False,
160170
Args:
161171
edge_index (LongTensor): The edge indices.
162172
p (float, optional): Ratio of added edges to the existing edges.
173+
(default: :obj:`0.5`)
163174
force_undirected (bool, optional): If set to :obj:`True`,
164175
added edges will be undirected.
165176
(default: :obj:`False`)
@@ -208,30 +219,24 @@ def add_random_edge(edge_index, p: float, force_undirected: bool = False,
208219
[1, 3, 2]])
209220
"""
210221
if p < 0. or p > 1.:
211-
raise ValueError(f'Ratio of added edges has to be between 0 and 1 '
212-
f'(got {p}')
222+
raise ValueError(f"Ratio of added edges has to be between 0 and 1 "
223+
f"(got '{p}')")
213224
if force_undirected and isinstance(num_nodes, (tuple, list)):
214-
raise RuntimeError('`force_undirected` is not supported for'
215-
' heterogeneous graphs')
225+
raise RuntimeError("'force_undirected' is not supported for "
226+
"bipartite graphs")
216227

217228
device = edge_index.device
218229
if not training or p == 0.0:
219230
edge_index_to_add = torch.tensor([[], []], device=device)
220231
return edge_index, edge_index_to_add
221232

222-
if not isinstance(num_nodes, (tuple, list)):
223-
num_nodes = (num_nodes, num_nodes)
224-
num_src_nodes = maybe_num_nodes(edge_index, num_nodes[0])
225-
num_dst_nodes = maybe_num_nodes(edge_index, num_nodes[1])
226-
227-
num_edges_to_add = round(edge_index.size(1) * p)
228-
row = torch.randint(0, num_src_nodes, size=(num_edges_to_add, ))
229-
col = torch.randint(0, num_dst_nodes, size=(num_edges_to_add, ))
233+
edge_index_to_add = negative_sampling(
234+
edge_index=edge_index,
235+
num_nodes=num_nodes,
236+
num_neg_samples=round(edge_index.size(1) * p),
237+
force_undirected=force_undirected,
238+
)
230239

231-
if force_undirected:
232-
mask = row < col
233-
row, col = row[mask], col[mask]
234-
row, col = torch.cat([row, col]), torch.cat([col, row])
235-
edge_index_to_add = torch.stack([row, col], dim=0).to(device)
236240
edge_index = torch.cat([edge_index, edge_index_to_add], dim=1)
241+
237242
return edge_index, edge_index_to_add

torch_geometric/utils/negative_sampling.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99
from torch_geometric.utils.num_nodes import maybe_num_nodes
1010

1111

12-
def negative_sampling(edge_index: Tensor,
13-
num_nodes: Optional[Union[int, Tuple[int, int]]] = None,
14-
num_neg_samples: Optional[int] = None,
15-
method: str = "sparse",
16-
force_undirected: bool = False) -> Tensor:
12+
def negative_sampling(
13+
edge_index: Tensor,
14+
num_nodes: Optional[Union[int, Tuple[int, int]]] = None,
15+
num_neg_samples: Optional[int] = None,
16+
method: str = "sparse",
17+
force_undirected: bool = False,
18+
) -> Tensor:
1719
r"""Samples random negative edges of a graph given by :attr:`edge_index`.
1820
1921
Args:

0 commit comments

Comments
 (0)