Skip to content

Commit 52ce0e4

Browse files
[Type Hints] datasets.Reddit and datasets.Reddit2 (#5695)
Co-authored-by: Matthias Fey <[email protected]>
1 parent d81176d commit 52ce0e4

File tree

3 files changed

+19
-7
lines changed

3 files changed

+19
-7
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4242
- Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641))
4343
- Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642))
4444
- Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
45-
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688))
45+
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695))
4646
- Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
4747
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614))
4848
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))

torch_geometric/datasets/reddit.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import os.path as osp
3+
from typing import Callable, List, Optional
34

45
import numpy as np
56
import scipy.sparse as sp
@@ -47,16 +48,21 @@ class Reddit(InMemoryDataset):
4748

4849
url = 'https://data.dgl.ai/dataset/reddit.zip'
4950

50-
def __init__(self, root, transform=None, pre_transform=None):
51+
def __init__(
52+
self,
53+
root: str,
54+
transform: Optional[Callable] = None,
55+
pre_transform: Optional[Callable] = None,
56+
):
5157
super().__init__(root, transform, pre_transform)
5258
self.data, self.slices = torch.load(self.processed_paths[0])
5359

5460
@property
55-
def raw_file_names(self):
61+
def raw_file_names(self) -> List[str]:
5662
return ['reddit_data.npz', 'reddit_graph.npz']
5763

5864
@property
59-
def processed_file_names(self):
65+
def processed_file_names(self) -> str:
6066
return 'data.pt'
6167

6268
def download(self):

torch_geometric/datasets/reddit2.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import os
33
import os.path as osp
4+
from typing import Callable, List, Optional
45

56
import numpy as np
67
import scipy.sparse as sp
@@ -54,16 +55,21 @@ class Reddit2(InMemoryDataset):
5455
class_map_id = '1JF3Pjv9OboMNYs2aXRQGbJbc4t_nDd5u'
5556
role_id = '1nJIKd77lcAGU4j-kVNx_AIGEkveIKz3A'
5657

57-
def __init__(self, root, transform=None, pre_transform=None):
58+
def __init__(
59+
self,
60+
root: str,
61+
transform: Optional[Callable] = None,
62+
pre_transform: Optional[Callable] = None,
63+
):
5864
super().__init__(root, transform, pre_transform)
5965
self.data, self.slices = torch.load(self.processed_paths[0])
6066

6167
@property
62-
def raw_file_names(self):
68+
def raw_file_names(self) -> List[str]:
6369
return ['adj_full.npz', 'feats.npy', 'class_map.json', 'role.json']
6470

6571
@property
66-
def processed_file_names(self):
72+
def processed_file_names(self) -> str:
6773
return 'data.pt'
6874

6975
def download(self):

0 commit comments

Comments
 (0)