Skip to content

Commit 6b0dd88

Browse files
authored
Merge branch 'master' into type_hints/ToDense
2 parents 648f202 + 888351b commit 6b0dd88

File tree

9 files changed

+72
-30
lines changed

9 files changed

+72
-30
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4040
- Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641))
4141
- Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642))
4242
- Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
43-
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668))
43+
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668))
4444
- Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
4545
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614))
4646
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))

test/utils/test_augmentation.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,20 +54,21 @@ def test_mask_feature():
5454

5555
torch.manual_seed(7)
5656
out = mask_feature(x, mode='all')
57-
assert out[0].tolist() == [[1.0, 0.0, 0.0, 4.0], [5.0, 6.0, 7.0, 0.0],
58-
[9.0, 10.0, 0.0, 12.0]]
57+
assert out[0].tolist() == [[0.0, 2.0, 3.0, 0.0], [5.0, 0.0, 0.0, 8.0],
58+
[0.0, 0.0, 0.0, 0.0]]
5959

60-
assert out[1].tolist() == [[True, False, False, True],
61-
[True, True, True, False],
62-
[True, True, False, True]]
60+
assert out[1].tolist() == [[False, True, True, False],
61+
[True, False, False, True],
62+
[False, False, False, False]]
6363

6464
torch.manual_seed(7)
6565
out = mask_feature(x, mode='all', fill_value=-1)
66-
assert out[0].tolist() == [[1.0, -1., -1., 4.0], [5.0, 6.0, 7.0, -1.],
67-
[9.0, 10.0, -1., 12.0]]
68-
assert out[1].tolist() == [[True, False, False, True],
69-
[True, True, True, False],
70-
[True, True, False, True]]
66+
assert out[0].tolist() == [[-1.0, 2.0, 3.0, -1.0], [5.0, -1.0, -1.0, 8.0],
67+
[-1.0, -1.0, -1.0, -1.0]]
68+
69+
assert out[1].tolist() == [[False, True, True, False],
70+
[True, False, False, True],
71+
[False, False, False, False]]
7172

7273

7374
def test_add_random_edge():

test/utils/test_isolated.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ def test_remove_isolated_nodes():
2828
assert out.tolist() == [[0, 1, 0], [1, 0, 0]]
2929
assert mask.tolist() == [1, 1]
3030

31+
if is_full_test():
32+
jit = torch.jit.script(remove_isolated_nodes)
33+
out, _, mask = jit(edge_index)
34+
assert out.tolist() == [[0, 1, 0], [1, 0, 0]]
35+
assert mask.tolist() == [1, 1]
36+
3137
out, _, mask = remove_isolated_nodes(edge_index, num_nodes=3)
3238
assert out.tolist() == [[0, 1, 0], [1, 0, 0]]
3339
assert mask.tolist() == [1, 1, 0]

test/utils/test_normalized_cut.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22

3+
from torch_geometric.testing import is_full_test
34
from torch_geometric.utils import normalized_cut
45

56

@@ -11,3 +12,8 @@ def test_normalized_cut():
1112

1213
output = normalized_cut(torch.stack([row, col], dim=0), edge_attr)
1314
assert output.tolist() == expected_output
15+
16+
if is_full_test():
17+
jit = torch.jit.script(normalized_cut)
18+
output = jit(torch.stack([row, col], dim=0), edge_attr)
19+
assert output.tolist() == expected_output

torch_geometric/datasets/icews.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,33 @@
1+
from typing import Callable, List, Optional
2+
13
import torch
24

35
from torch_geometric.data import Data, InMemoryDataset, download_url
46
from torch_geometric.io import read_txt_array
57

68

79
class EventDataset(InMemoryDataset):
8-
def __init__(self, root, transform=None, pre_transform=None,
9-
pre_filter=None):
10+
def __init__(
11+
self,
12+
root: str,
13+
transform: Optional[Callable] = None,
14+
pre_transform: Optional[Callable] = None,
15+
pre_filter: Optional[Callable] = None,
16+
):
1017
super().__init__(root, transform, pre_transform, pre_filter)
1118

1219
@property
13-
def num_nodes(self):
20+
def num_nodes(self) -> int:
1421
raise NotImplementedError
1522

1623
@property
17-
def num_rels(self):
24+
def num_rels(self) -> int:
1825
raise NotImplementedError
1926

20-
def process_events(self):
27+
def process_events(self) -> int:
2128
raise NotImplementedError
2229

23-
def process(self):
30+
def process(self) -> List[Data]:
2431
events = self.process_events()
2532
events = events - events.min(dim=0, keepdim=True)[0]
2633

@@ -64,34 +71,40 @@ class ICEWS18(EventDataset):
6471
url = 'https://github.com/INK-USC/RE-Net/raw/master/data/ICEWS18'
6572
splits = [0, 373018, 419013, 468558] # Train/Val/Test splits.
6673

67-
def __init__(self, root, split='train', transform=None, pre_transform=None,
68-
pre_filter=None):
74+
def __init__(
75+
self,
76+
root: str,
77+
split: str = 'train',
78+
transform: Optional[Callable] = None,
79+
pre_transform: Optional[Callable] = None,
80+
pre_filter: Optional[Callable] = None,
81+
):
6982
assert split in ['train', 'val', 'test']
7083
super().__init__(root, transform, pre_transform, pre_filter)
7184
idx = self.processed_file_names.index(f'{split}.pt')
7285
self.data, self.slices = torch.load(self.processed_paths[idx])
7386

7487
@property
75-
def num_nodes(self):
88+
def num_nodes(self) -> int:
7689
return 23033
7790

7891
@property
79-
def num_rels(self):
92+
def num_rels(self) -> int:
8093
return 256
8194

8295
@property
83-
def raw_file_names(self):
96+
def raw_file_names(self) -> List[str]:
8497
return [f'{name}.txt' for name in ['train', 'valid', 'test']]
8598

8699
@property
87-
def processed_file_names(self):
100+
def processed_file_names(self) -> List[str]:
88101
return ['train.pt', 'val.pt', 'test.pt']
89102

90103
def download(self):
91104
for filename in self.raw_file_names:
92105
download_url(f'{self.url}/{filename}', self.raw_dir)
93106

94-
def process_events(self):
107+
def process_events(self) -> torch.Tensor:
95108
events = []
96109
for path in self.raw_paths:
97110
data = read_txt_array(path, sep='\t', end=4, dtype=torch.long)

torch_geometric/transforms/one_hot_degree.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
import torch.nn.functional as F
33

4+
from torch_geometric.data import Data
45
from torch_geometric.data.datapipes import functional_transform
56
from torch_geometric.transforms import BaseTransform
67
from torch_geometric.utils import degree
@@ -19,12 +20,17 @@ class OneHotDegree(BaseTransform):
1920
cat (bool, optional): Concat node degrees to node features instead
2021
of replacing them. (default: :obj:`True`)
2122
"""
22-
def __init__(self, max_degree, in_degree=False, cat=True):
23+
def __init__(
24+
self,
25+
max_degree: int,
26+
in_degree: bool = False,
27+
cat: bool = True,
28+
):
2329
self.max_degree = max_degree
2430
self.in_degree = in_degree
2531
self.cat = cat
2632

27-
def __call__(self, data):
33+
def __call__(self, data: Data) -> Data:
2834
idx, x = data.edge_index[1 if self.in_degree else 0], data.x
2935
deg = degree(idx, data.num_nodes, dtype=torch.long)
3036
deg = F.one_hot(deg, num_classes=self.max_degree + 1).to(torch.float)

torch_geometric/utils/augmentation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def mask_feature(x: Tensor, p: float = 0.5, mode: str = 'col',
143143
mask = torch.rand(x.size(1), device=x.device) >= p
144144
mask = mask.view(1, -1)
145145
else:
146-
mask = x.bernoulli(1 - p).to(torch.bool)
146+
mask = torch.randn_like(x) >= p
147147

148148
x = x.masked_fill(~mask, fill_value)
149149
return x, mask

torch_geometric/utils/isolated.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Optional, Tuple
22

33
import torch
44
from torch import Tensor
@@ -37,7 +37,11 @@ def contains_isolated_nodes(
3737
return torch.unique(edge_index.view(-1)).numel() < num_nodes
3838

3939

40-
def remove_isolated_nodes(edge_index, edge_attr=None, num_nodes=None):
40+
def remove_isolated_nodes(
41+
edge_index: Tensor,
42+
edge_attr: Optional[Tensor] = None,
43+
num_nodes: Optional[int] = None,
44+
) -> Tuple[Tensor, Optional[Tensor], Tensor]:
4145
r"""Removes the isolated nodes from the graph given by :attr:`edge_index`
4246
with optional edge attributes :attr:`edge_attr`.
4347
In addition, returns a mask of shape :obj:`[num_nodes]` to manually filter

torch_geometric/utils/normalized_cut.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
from typing import Optional
22

3+
from torch import Tensor
4+
35
from torch_geometric.utils import degree
46

57

6-
def normalized_cut(edge_index, edge_attr, num_nodes: Optional[int] = None):
8+
def normalized_cut(
9+
edge_index: Tensor,
10+
edge_attr: Tensor,
11+
num_nodes: Optional[int] = None,
12+
) -> Tensor:
713
r"""Computes the normalized cut :math:`\mathbf{e}_{i,j} \cdot
814
\left( \frac{1}{\deg(i)} + \frac{1}{\deg(j)} \right)` of a weighted graph
915
given by edge indices and edge attributes.

0 commit comments

Comments
 (0)