
What is the `node_dim` argument referring to in the message passing class?


In the PyTorch Geometric tutorial on creating Message Passing Networks, the explanation of the class opens with this paragraph:

MessagePassing(aggr="add", flow="source_to_target", node_dim=-2): Defines the aggregation scheme to use ("add", "mean" or "max") and the flow direction of message passing (either "source_to_target" or "target_to_source"). Furthermore, the node_dim attribute indicates along which axis to propagate.

I don't understand what this node_dim refers to, or why its default is -2. I have looked at the documentation for the MessagePassing class, which says it is "the axis along which to propagate" -- but this still doesn't clarify what is actually happening or why the default is -2 (presumably it has something to do with propagating information at the node level). Could someone explain this to me, please?


Solution

  • After referring to here and here, I believe node_dim relates to the output of the message function.
    In most cases, that output has shape [edge_num, emb_out]. Setting node_dim to -2 means aggregation happens along the edge_num dimension, using the indices of the target nodes to decide which rows are summed together.
    This is exactly the step that gathers information from source nodes into their targets.
    After aggregation, the result has shape [node_num, emb_out].
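The aggregation step above can be sketched in plain PyTorch, without PyTorch Geometric. This is a hypothetical toy example (the graph, message values, and variable names are made up) showing how per-edge messages of shape [edge_num, emb_out] are scatter-added along dim -2 into a [node_num, emb_out] tensor, indexed by the target node of each edge -- which is what aggr="add" with node_dim=-2 does:

```python
import torch

# Hypothetical toy graph: 3 nodes, 4 directed edges.
# Row 0 holds source nodes, row 1 holds target nodes.
edge_index = torch.tensor([[0, 1, 2, 2],
                           [1, 2, 0, 1]])
num_nodes, emb_out = 3, 2

# Stand-in for the output of message(): one row per edge,
# shape [edge_num, emb_out] = [4, 2].
messages = torch.arange(4 * emb_out, dtype=torch.float).view(4, emb_out)

# Aggregate with "add": scatter each message row into the slot of its
# target node. For a 2-D tensor, dim 0 is the same axis as dim -2,
# i.e. the edge dimension we aggregate away.
target = edge_index[1]
out = torch.zeros(num_nodes, emb_out)
out.index_add_(0, target, messages)

print(out.shape)  # torch.Size([3, 2]) -> [node_num, emb_out]
print(out[1])     # tensor([6., 8.]): sum of messages from edges 0 and 3
```

Node 1 receives two messages (from edges 0 and 3), which are summed into a single row; this is why the edge dimension disappears and the result has one row per node.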