|
5 | 5 |
|
6 | 6 | from torch_geometric.typing import OptTensor
|
7 | 7 |
|
| 8 | +from .num_nodes import maybe_num_nodes |
| 9 | +from .subgraph import subgraph |
| 10 | + |
8 | 11 |
|
9 | 12 | def filter_adj(row: Tensor, col: Tensor, edge_attr: OptTensor,
|
10 | 13 | mask: Tensor) -> Tuple[Tensor, Tensor, OptTensor]:
|
@@ -79,3 +82,56 @@ def dropout_adj(
|
79 | 82 | edge_index = torch.stack([row, col], dim=0)
|
80 | 83 |
|
81 | 84 | return edge_index, edge_attr
|
| 85 | + |
| 86 | + |
| 87 | +def dropout_node(edge_index: Tensor, p: float = 0.5, |
| 88 | + num_nodes: Optional[int] = None, |
| 89 | + training: bool = True) -> Tuple[Tensor, Tensor, Tensor]: |
| 90 | + r"""Randomly drops nodes from the adjacency matrix |
| 91 | + :obj:`edge_index` with probability :obj:`p` using samples from |
| 92 | + a Bernoulli distribution. |
| 93 | +
|
| 94 | + The method returns (1) the retained :obj:`edge_index`, (2) the edge mask |
| 95 | + indicating which edges were dropped. (3) the node mask indicating |
| 96 | + which nodes were dropped. |
| 97 | +
|
| 98 | + Args: |
| 99 | + edge_index (LongTensor): The edge indices. |
| 100 | + p (float, optional): Dropout probability. (default: :obj:`0.5`) |
| 101 | + num_nodes (int, optional): The number of nodes, *i.e.* |
| 102 | + :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) |
| 103 | + training (bool, optional): If set to :obj:`False`, this operation is a |
| 104 | + no-op. (default: :obj:`True`) |
| 105 | +
|
| 106 | + :rtype: (:class:`LongTensor`, :class:`BoolTensor`, :class:`BoolTensor`) |
| 107 | +
|
| 108 | + Examples: |
| 109 | +
|
| 110 | + >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], |
| 111 | + ... [1, 0, 2, 1, 3, 2]]) |
| 112 | + >>> edge_index, edge_mask, node_mask = dropout_node(edge_index) |
| 113 | + >>> edge_index |
| 114 | + tensor([[0, 1], |
| 115 | + [1, 0]]) |
| 116 | + >>> edge_mask |
| 117 | + tensor([ True, True, False, False, False, False]) |
| 118 | + >>> node_mask |
| 119 | + tensor([ True, True, False, False]) |
| 120 | + """ |
| 121 | + if p < 0. or p > 1.: |
| 122 | + raise ValueError(f'Dropout probability has to be between 0 and 1 ' |
| 123 | + f'(got {p}') |
| 124 | + |
| 125 | + num_nodes = maybe_num_nodes(edge_index, num_nodes) |
| 126 | + |
| 127 | + if not training or p == 0.0: |
| 128 | + node_mask = edge_index.new_zeros(num_nodes, dtype=torch.bool) |
| 129 | + edge_mask = edge_index.new_zeros(edge_index.size(1), dtype=torch.bool) |
| 130 | + return edge_index, edge_mask, node_mask |
| 131 | + |
| 132 | + prob = torch.rand(num_nodes, device=edge_index.device) |
| 133 | + node_mask = prob > p |
| 134 | + edge_index, _, edge_mask = subgraph(node_mask, edge_index, |
| 135 | + num_nodes=num_nodes, |
| 136 | + return_edge_mask=True) |
| 137 | + return edge_index, edge_mask, node_mask |
0 commit comments