python machine-learning pytorch neural-network conv-neural-network

PyTorch GCN layer forward pass produces unexpected output features


My short script for a GCN layer forward pass does not produce the values I expect. The script defines a graph with 3 nodes named 0, 1 and 2. The edge set of the graph is {{0,1},{1,2}}. The 1-dimensional input feature is 10.0 for node 0, 20.0 for node 1 and 30.0 for node 2.

import torch
from torch_geometric.nn import GCNConv

x = torch.tensor([[10.0], [20.0], [30.0]], dtype=torch.float)
edge_index = torch.tensor([[0,1], [1,2]], dtype=torch.long)

new_lin_weight_values = torch.tensor([[1.0]])
new_lin_bias_values = torch.tensor([[0.0]])

conv_layer = GCNConv(in_channels=1, out_channels=1)

# Set the weight and bias values directly as parameters
conv_layer.lin.weight = torch.nn.Parameter(new_lin_weight_values)
conv_layer.lin.bias = torch.nn.Parameter(new_lin_bias_values)

# Disable the learning of parameters by freezing them
conv_layer.lin.weight.requires_grad = False
conv_layer.lin.bias.requires_grad = False

output = conv_layer.forward(x, edge_index)

print("Input Features:")
print(x)
print("Output Features:")
print(output)

Since a GCN layer computes a node's new representation from its own input features and those of its neighbors, I would expect the output to be something like 15.0 for node 0, 20.0 for node 1 and 25.0 for node 2.

However, the script outputs 10.0 for node 0, 17.0711 for node 1 and 25.0 for node 2. The value for node 2 makes sense to me, but the output for node 1 in particular is inexplicable.

Why do I get these weird values?

I made sure to fix the weight matrix and the bias so that the computation is deterministic.


Solution

  • I believe you are confusing a Graph Convolutional Network (GCN) with a more basic message-passing network.

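    In a GCN as originally introduced, each message is scaled by a normalisation factor of 1/sqrt(deg(i) * deg(j)), the inverse square root of the product of the sender's and receiver's degrees. In GCNConv, each degree also counts the self-loop that the layer adds by default.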

    Furthermore, note that edge_index holds directed edges: each column is one (source, target) pair, so your tensor encodes only the edges 0 -> 1 and 1 -> 2. If you intend to work on an undirected graph, you have to add both directions of every edge.

    The exact propagation rule is given in the official documentation: https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GCNConv.html
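
The normalisation described above can be checked by hand. The following is a minimal sketch in plain Python, assuming a weight of 1.0, a bias of 0.0 and GCNConv's default self-loops, as in the question. `gcn_propagate` is a hypothetical helper written only for illustration; it is not part of torch_geometric.

```python
import math


def gcn_propagate(x, edges):
    """Symmetric-normalised GCN aggregation for scalar node features.

    x:     list with one feature value per node
    edges: list of directed (source, target) pairs
    Assumes weight 1.0 and bias 0.0, as in the question.
    """
    n = len(x)
    # Add one self-loop per node, mirroring GCNConv's default behaviour.
    all_edges = list(edges) + [(i, i) for i in range(n)]

    # Degree = number of incoming edges, self-loop included.
    deg = [0] * n
    for _, dst in all_edges:
        deg[dst] += 1

    # Each message x[src] is scaled by 1 / sqrt(deg[src] * deg[dst]).
    out = [0.0] * n
    for src, dst in all_edges:
        out[dst] += x[src] / math.sqrt(deg[src] * deg[dst])
    return out


x = [10.0, 20.0, 30.0]

# The question's edge_index read column by column: 0 -> 1 and 1 -> 2.
directed = [(0, 1), (1, 2)]
print([round(v, 4) for v in gcn_propagate(x, directed)])
# [10.0, 17.0711, 25.0]

# Both directions of the undirected edges {0,1} and {1,2}.
undirected = [(0, 1), (1, 0), (1, 2), (2, 1)]
print([round(v, 4) for v in gcn_propagate(x, undirected)])
```

With the directed edges, node 0 receives no incoming message, so it keeps 10.0. Node 1 gets 10/sqrt(1*2) + 20/sqrt(2*2), which is about 17.0711, and node 2 gets 20/sqrt(2*2) + 30/sqrt(2*2) = 25.0, matching the output in the question. Adding the reverse edges changes the result, but it still does not give the plain averages 15.0, 20.0 and 25.0, because the GCN normalisation is not a mean. In torch_geometric itself, `torch_geometric.utils.to_undirected` can be used to build such a bidirectional edge_index.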