In the PyTorch Geometric tutorial on creating message passing networks, there is this paragraph at the start that explains what the class does:
MessagePassing(aggr="add", flow="source_to_target", node_dim=-2): Defines the aggregation scheme to use ("add", "mean" or "max") and the flow direction of message passing (either "source_to_target" or "target_to_source"). Furthermore, the node_dim attribute indicates along which axis to propagate.
I don't understand what this node_dim refers to, or why it is -2. I have looked at the documentation for the MessagePassing class, and it says only that this is the axis along which to propagate. That still doesn't clarify what we are doing here or why the default is -2 (presumably that is how you propagate information at a node level). Could someone please explain this?
After referring to here and here, I think node_dim relates to the output of the message function.
In most cases, the shape of that output is [edge_num, emb_out], and setting node_dim to -2 means that we aggregate along the edge_num axis (the second-to-last one), using the indices of the target nodes. This is exactly the step that collects the information from the source nodes at each target node. The result after aggregation has shape [node_num, emb_out].
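To make this concrete, here is a minimal sketch in plain PyTorch of what I understand the "add" aggregation to be doing. It is not PyG's actual implementation: the toy graph, the message values, and the variable names are all made up for illustration, and I use Tensor.index_add_ as a stand-in for the library's internal scatter step.

```python
import torch

# Hypothetical toy graph: 3 nodes, 4 directed edges (source -> target).
edge_index = torch.tensor([[0, 1, 2, 2],   # source nodes
                           [1, 2, 0, 1]])  # target nodes
num_nodes = 3
node_dim = -2  # same value as the MessagePassing default

# Stand-in for the output of message(): one row per edge,
# i.e. shape [edge_num, emb_out] = [4, 2].
messages = torch.tensor([[1., 10.],
                         [2., 20.],
                         [3., 30.],
                         [4., 40.]])

# Sum the messages along axis -2 (the edge axis), grouped by target node.
target = edge_index[1]
out = torch.zeros(num_nodes, messages.size(-1))
out.index_add_(node_dim, target, messages)
print(out.shape)  # torch.Size([3, 2]), i.e. [node_num, emb_out]
print(out)        # rows: [3, 30], [5, 50], [2, 20]

# With an extra leading axis (assumed here to be a batch of 2 feature sets
# on the same graph), axis -2 is still the edge axis, whereas axis 0 is not.
batched = torch.stack([messages, 2 * messages])          # [2, 4, 2]
out_batched = torch.zeros(2, num_nodes, messages.size(-1))
out_batched.index_add_(node_dim, target, batched)        # [2, 3, 2]
```

Node 1 is the target of two edges (from nodes 0 and 2), so its row is the sum of those two messages. The batched case is, as far as I can tell, the reason for counting from the end: -2 keeps pointing at the node/edge axis even when extra leading dimensions are present.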