Skip to content

Commit 7029bad

Browse files
EdisonLeeeeerusty1sPadarn
authored
Add assortativity to torch_geometric.utils (#5587)
* add assortativity * test * doc-string * doc-string * Update torch_geometric/utils/assortativity.py Co-authored-by: Matthias Fey <[email protected]> * Update torch_geometric/utils/assortativity.py Co-authored-by: Matthias Fey <[email protected]> * Update torch_geometric/utils/assortativity.py Co-authored-by: Padarn Wilson <[email protected]> * update test * doc-string * fix test * changelog Co-authored-by: Matthias Fey <[email protected]> Co-authored-by: Padarn Wilson <[email protected]>
1 parent 6ca2332 commit 7029bad

File tree

4 files changed

+88
-1
lines changed

4 files changed

+88
-1
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
55

66
## [2.2.0] - 2022-MM-DD
77
### Added
8+
- Added `assortativity` that computes degree assortativity coefficient ([#5587](https://github.com/pyg-team/pytorch_geometric/pull/5587))
89
- Added `SSGConv` layer ([#5599](https://github.com/pyg-team/pytorch_geometric/pull/5599))
910
- Added `shuffle_node`, `mask_feature` and `add_random_edge` augmentation methdos ([#5548](https://github.com/pyg-team/pytorch_geometric/pull/5548))
1011
- Added `dropout_path` augmentation that drops edges from a graph based on random walks ([#5531](https://github.com/pyg-team/pytorch_geometric/pull/5531))

test/utils/test_assortativity.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import pytest
2+
import torch
3+
4+
from torch_geometric.utils import assortativity
5+
6+
7+
def test_assortativity():
8+
# completely assortative graph
9+
edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5],
10+
[1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]])
11+
out = assortativity(edge_index)
12+
assert pytest.approx(out, abs=1e-5) == 1.0
13+
14+
# completely disassortative graph
15+
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 5, 5, 5, 5],
16+
[5, 5, 5, 5, 5, 0, 1, 2, 3, 4]])
17+
out = assortativity(edge_index)
18+
assert pytest.approx(out, abs=1e-5) == -1.0

torch_geometric/utils/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .degree import degree
22
from .softmax import softmax
33
from .dropout import dropout_adj, dropout_node, dropout_edge, dropout_path
4+
from .augmentation import shuffle_node, mask_feature, add_random_edge
45
from .sort_edge_index import sort_edge_index
56
from .coalesce import coalesce
67
from .undirected import is_undirected, to_undirected
@@ -11,6 +12,7 @@
1112
from .subgraph import (get_num_hops, subgraph, k_hop_subgraph,
1213
bipartite_subgraph)
1314
from .homophily import homophily
15+
from .assortativity import assortativity
1416
from .get_laplacian import get_laplacian
1517
from .get_mesh_laplacian import get_mesh_laplacian
1618
from .mask import index_to_mask, mask_to_index
@@ -34,7 +36,6 @@
3436
structured_negative_sampling_feasible)
3537
from .train_test_split_edges import train_test_split_edges
3638
from .scatter import scatter
37-
from .augmentation import shuffle_node, mask_feature, add_random_edge
3839

3940
__all__ = [
4041
'degree',
@@ -63,6 +64,7 @@
6364
'bipartite_subgraph',
6465
'k_hop_subgraph',
6566
'homophily',
67+
'assortativity',
6668
'get_laplacian',
6769
'get_mesh_laplacian',
6870
'index_to_mask',
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import torch
2+
from torch_sparse import SparseTensor
3+
4+
from torch_geometric.typing import Adj
5+
from torch_geometric.utils import coalesce, degree
6+
7+
from .to_dense_adj import to_dense_adj
8+
9+
10+
def assortativity(edge_index: Adj) -> float:
11+
r"""The degree assortativity coefficient from the
12+
`"Mixing patterns in networks"
13+
<https://arxiv.org/abs/cond-mat/0209450>`_ paper.
14+
Assortativity in a network refers to the tendency of nodes to
15+
connect with other similar nodes over dissimilar nodes.
16+
It is computed from Pearson correlation coefficient of the node degrees.
17+
18+
Args:
19+
edge_index (Tensor or SparseTensor): The graph connectivity.
20+
21+
Returns:
22+
The value of the degree assortativity coefficient for the input
23+
graph :math:`\in [-1, 1]`
24+
25+
Example:
26+
27+
>>> edge_index = torch.tensor([[0, 1, 2, 3, 2],
28+
... [1, 2, 0, 1, 3]])
29+
>>> assortativity(edge_index)
30+
-0.666667640209198
31+
"""
32+
if isinstance(edge_index, SparseTensor):
33+
row, col, _ = edge_index.coo()
34+
else:
35+
row, col = edge_index
36+
37+
device = row.device
38+
out_deg = degree(row, dtype=torch.long)
39+
in_deg = degree(col, dtype=torch.long)
40+
degrees = torch.unique(torch.cat([out_deg, in_deg]))
41+
mapping = row.new_zeros(degrees.max().item() + 1)
42+
mapping[degrees] = torch.arange(degrees.size(0), device=device)
43+
44+
# Compute degree mixing matrix (joint probability distribution) `M`
45+
num_degrees = degrees.size(0)
46+
src_deg = mapping[out_deg[row]]
47+
dst_deg = mapping[in_deg[col]]
48+
49+
pairs = torch.stack([src_deg, dst_deg], dim=0)
50+
occurrence = torch.ones(pairs.size(1), device=device)
51+
pairs, occurrence = coalesce(pairs, occurrence)
52+
M = to_dense_adj(pairs, edge_attr=occurrence, max_num_nodes=num_degrees)[0]
53+
# normalization
54+
M /= M.sum()
55+
56+
# numeric assortativity coefficient, computed by
57+
# Pearson correlation coefficient of the node degrees
58+
x = y = degrees.float()
59+
a, b = M.sum(0), M.sum(1)
60+
61+
vara = (a * x**2).sum() - ((a * x).sum())**2
62+
varb = (b * x**2).sum() - ((b * x).sum())**2
63+
xy = torch.outer(x, y)
64+
ab = torch.outer(a, b)
65+
out = (xy * (M - ab)).sum() / (vara * varb).sqrt()
66+
return out.item()

0 commit comments

Comments
 (0)