SetTransformerAggregation returns nan for an unconnected node. #7899

Closed
@ATheCoder

🐛 Describe the bug

When `SetTransformerAggregation` is used as the aggregation of a `MessagePassing` layer and the input graph contains one or more nodes that are disconnected from the rest of the graph (so they receive no messages), `SetTransformerAggregation` returns NaN for those nodes. `SumAggregation`, by contrast, returns plain 0 for them.
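
For comparison, the behavior of a sum aggregation can be sketched in NumPy. This is a conceptual stand-in, not PyG's `SumAggregation` code, and `scatter_sum` is a hypothetical name: each message is added into the row of its target node, so a node that receives no messages keeps its initial value of 0.

```python
import numpy as np


def scatter_sum(messages: np.ndarray, index: np.ndarray, dim_size: int) -> np.ndarray:
    """Sum message rows into `dim_size` output rows, grouped by target `index`."""
    out = np.zeros((dim_size, messages.shape[1]), dtype=messages.dtype)
    # np.add.at adds in place without buffering, so repeated indices accumulate.
    np.add.at(out, index, messages)
    return out


# Three nodes; no message targets node 2.
messages = np.array([[1.0, 2.0], [3.0, 4.0]])
index = np.array([0, 1])  # target node of each message
print(scatter_sum(messages, index, dim_size=3))  # row 2 stays all zeros
```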

```python
from torch import Tensor
import torch
from torch_geometric.nn import MessagePassing, SetTransformerAggregation
from torch_geometric.data import Data, Batch
from torch_geometric.utils import sort_edge_index


class MPNN4Set(MessagePassing):
    def __init__(self, dim, n_heads):
        super().__init__()
        self.dim = dim
        self.aggregator = SetTransformerAggregation(dim, heads=n_heads)

    def forward(self, h, edge_index, batch):
        # SetTransformerAggregation expects the aggregation index (the target
        # nodes, edge_index[1]) to be sorted, hence sorting by column.
        edge_index = sort_edge_index(edge_index, sort_by_row=False)
        h = self.propagate(edge_index, x=h, num_nodes=h.size(0), batch=batch)
        return h

    def message(self, x_i, x_j, edge_index, num_nodes, batch):
        # Each message is simply the source node's features.
        return x_j

    def aggregate(self, inputs: Tensor, index: Tensor, ptr: Tensor | None = None,
                  dim_size: int | None = None) -> Tensor:
        # Delegate the aggregation to the set transformer.
        return self.aggregator(inputs, index, ptr, dim_size)

    def update(self, aggr_out, batch):
        return aggr_out


m = MPNN4Set(10, 2)
graphs = [
    # Graph 0: edges 0 -> 1 and 1 -> 0; node 2 has no edges at all.
    Data(x=torch.randn((3, 10)),
         edge_index=torch.tensor([[0, 1], [1, 0]], dtype=torch.long)),
    # Graph 1: every node has an incoming edge.
    Data(x=torch.randn((3, 10)),
         edge_index=torch.tensor([[0, 1, 2], [2, 1, 0]], dtype=torch.long)),
]

batched_graphs = Batch.from_data_list(graphs)

res = m(batched_graphs.x, batched_graphs.edge_index, batched_graphs.batch)

# Node 2 (the disconnected node of graph 0) comes out as NaN.
assert res[2].isnan().any().item() is True
```
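
To see which nodes are affected without running the layer, one can count each node's incoming edges: nodes with in-degree 0 receive an empty set of messages. The NumPy sketch below hard-codes the `edge_index` that `Batch.from_data_list` should produce for the two graphs above (the second graph's node indices shifted by 3); it is an illustration, not part of PyG.

```python
import numpy as np

# Batched edge_index for the two example graphs; graph 1's nodes are offset by 3.
edge_index = np.array([[0, 1, 3, 4, 5],
                       [1, 0, 5, 4, 3]])
num_nodes = 6

# With the default source -> target flow, messages are aggregated at edge_index[1],
# so the in-degree of a node is how often it appears in that row.
in_degree = np.bincount(edge_index[1], minlength=num_nodes)
isolated = np.flatnonzero(in_degree == 0)
print(in_degree)  # [1 1 0 1 1 1]
print(isolated)   # [2]
```

Node 2 is the only node with no incoming messages, which matches the NaN row observed in `res[2]`.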

I debugged this a little, and it seems to stem from PyTorch's `nn.MultiheadAttention`: if the key padding mask masks out every key of a sequence (a whole row of the mask is `True`), the output for that sequence contains NaN:

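```python
import torch
from torch import nn

# embed_dim=10, num_heads=2; with batch_first=False inputs are (seq, batch, embed).
m = nn.MultiheadAttention(10, 2)

t1 = torch.randn((3, 3, 10))
# key_padding_mask has shape (batch, seq); True means "ignore this key".
# Row 0 masks every key of batch element 0.
mask = torch.tensor([[True, True, True], [False, False, False], [False, False, False]])

out, weights = m(t1, t1, t1, key_padding_mask=mask)  # out includes NaN
```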
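
The NaN can be reproduced without PyTorch. Masked keys typically get a score of `-inf` before the softmax; if every key in a row is masked, the max-subtracted scores become `-inf - (-inf) = NaN`. The NumPy sketch below contrasts a naive masked softmax with a hypothetical `safe_masked_softmax` that gives fully masked rows all-zero weights, so the attention-weighted sum over the values is 0 instead of NaN. This is one possible workaround, assumed for illustration; it is not necessarily how PyG resolved the issue.

```python
import numpy as np


def masked_softmax(scores, key_padding_mask):
    """Softmax over the last axis; True in key_padding_mask means 'ignore this key'."""
    scores = np.where(key_padding_mask, -np.inf, scores)
    # Subtract the row max for stability; a fully masked row has max -inf,
    # and -inf - (-inf) is NaN, which then propagates through exp and the division.
    shifted = scores - scores.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def safe_masked_softmax(scores, key_padding_mask):
    """Like masked_softmax, but rows with every key masked get all-zero weights."""
    fully_masked = key_padding_mask.all(axis=-1, keepdims=True)
    # Temporarily unmask fully masked rows so the softmax stays finite...
    weights = masked_softmax(scores, key_padding_mask & ~fully_masked)
    # ...then zero those rows out.
    return np.where(fully_masked, 0.0, weights)


scores = np.zeros((2, 3))
mask = np.array([[True, True, True], [False, False, False]])
with np.errstate(invalid="ignore"):
    print(masked_softmax(scores, mask))   # row 0 is all NaN
print(safe_masked_softmax(scores, mask))  # row 0 is all zeros
```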

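This happens because the unbatching step (which turns the flat list of messages into a padded dense batch plus a mask) masks out the entire row for that node: since the node is not connected to any other node, it receives no messages, so every slot in its row is padding.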
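
The padding step can be sketched in NumPy as a simplified stand-in for PyG's dense batching (the name `to_dense_batch_sketch` is hypothetical, and `index` is assumed to be sorted, as `sort_edge_index(..., sort_by_row=False)` ensures in the reproduction). Each output row holds the messages for one target node, and the boolean mask marks the real slots. Assuming the mask is then inverted into a key padding mask for attention, a node with no messages ends up with every key masked.

```python
import numpy as np


def to_dense_batch_sketch(messages, index, dim_size):
    """Pack messages (sorted by target `index`) into a padded (dim_size, max_count, F) array."""
    counts = np.bincount(index, minlength=dim_size)
    max_count = max(int(counts.max()), 1)
    dense = np.zeros((dim_size, max_count, messages.shape[1]), dtype=messages.dtype)
    mask = np.zeros((dim_size, max_count), dtype=bool)  # True = real message
    # Offset of each target's first message, then each message's slot within its group.
    starts = np.concatenate(([0], np.cumsum(counts)[:-1]))
    slots = np.arange(len(index)) - starts[index]
    dense[index, slots] = messages
    mask[index, slots] = True
    return dense, mask


# Sorted targets from the batched example: node 2 receives nothing.
index = np.array([0, 1, 3, 4, 5])
messages = np.ones((5, 2))
dense, mask = to_dense_batch_sketch(messages, index, dim_size=6)
key_padding_mask = ~mask  # PyTorch convention: True = ignore this key
print(key_padding_mask[2])  # every key of node 2 is masked
```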

Environment

  • PyG version: 2.3.1
  • PyTorch version: 2.1.0a0+b5021ba
  • OS: Ubuntu 22.04
  • Python version: 3.10.6
  • CUDA/cuDNN version: 12.2
  • How you installed PyTorch and PyG (conda, pip, source): pip
