python machine-learning pytorch data-science pytorch-geometric

PyTorch Geometric: how do I interpret the output of the code snippet below?


I am reading the PyTorch Geometric documentation at https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html

On this page, there is a code snippet:

import torch
from torch_geometric.data import Data

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index)

When I display data (the object created on the last line of the snippet), the output is:

Data(edge_index=[2, 4], x=[3, 1])

Where do the values 2 and 4 in edge_index=[2, 4] come from? If I understand correctly, four edges are defined, with node indices starting from 0. Is this assumption wrong? Also, what does x=[3, 1] mean?
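A short sketch of how the edge_index tensor from the snippet can be read, assuming only torch is installed: each column is one directed edge (source node in row 0, target node in row 1), and node indices start at 0.

```python
import torch

# Same connectivity as in the snippet: row 0 holds source nodes, row 1 target nodes.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)

# Shape is (2, num_edges): 2 rows (source/target), one column per directed edge.
print(edge_index.shape)  # torch.Size([2, 4])

# Transposing turns each column into a (source, target) pair.
pairs = [tuple(p) for p in edge_index.t().tolist()]
print(pairs)  # [(0, 1), (1, 0), (1, 2), (2, 1)]

# Node indices are 0-based, so the largest index + 1 gives the number of nodes
# referenced by the edges (3 here: nodes 0, 1 and 2).
num_nodes = int(edge_index.max()) + 1
print(num_nodes)  # 3
```

The four columns encode the two undirected connections (between nodes 0 and 1, and between nodes 1 and 2), each stored once in each direction.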

Data is a class, so I wouldn't expect it to return anything. The class definition is here: https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html . According to the documentation, x is the node feature matrix and edge_index is the graph connectivity, but I don't understand the output I see when I check it in a Jupyter notebook.


Solution

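  • Okay, I think I now understand the output Data(edge_index=[2, 4], x=[3, 1]). The values in brackets are the shapes of the tensors, not their contents: [2, 4] is the shape of edge_index (2 rows holding source and target node indices, and 4 columns, one per directed edge), and [3, 1] is the shape of x (3 nodes, each with 1 feature). Please correct me if I am wrong.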
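The shape reading can be checked by building the tensors from scratch with plain torch: two undirected edges stored in both directions give 4 columns, and giving each node a second feature changes the shape of x from [3, 1] to [3, 2]. The extra feature values (all ones) are hypothetical, chosen only for illustration.

```python
import torch

# Two undirected edges (0-1 and 1-2), written once each as (source, target) rows.
undirected = torch.tensor([[0, 1],
                           [1, 2]], dtype=torch.long)

# Add each edge in the reverse direction, then transpose to the (2, num_edges) layout.
edge_index = torch.cat([undirected, undirected.flip(1)], dim=0).t()
print(edge_index.tolist())     # [[0, 1, 1, 2], [1, 2, 0, 1]]
print(list(edge_index.shape))  # [2, 4]

# One feature per node gives shape [num_nodes, num_features] = [3, 1].
x = torch.tensor([[-1.0], [0.0], [1.0]])
print(list(x.shape))           # [3, 1]

# A hypothetical second feature per node changes only the last dimension.
x2 = torch.cat([x, torch.ones(3, 1)], dim=1)
print(list(x2.shape))          # [3, 2]
```

The column order differs from the snippet's edge_index, but it describes the same four directed edges.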