Tags: pytorch, batching, pytorch-geometric

How to batch a nested list of graphs in PyTorch Geometric


I am currently training a model that combines graph neural networks with an LSTM, so each training sample is a list (sequence) of graphs rather than a single graph. torch_geometric supports batching with torch_geometric.data.Batch.from_data_list(), but that treats each data point as a single graph. How else can I batch samples that each contain a list of graphs?


Solution

  • Use diagonal batching:

    https://pytorch-geometric.readthedocs.io/en/latest/notes/batching.html

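    In short, all graphs are merged into one large graph as isolated subgraphs: their adjacency matrices are stacked along the diagonal and their node features are concatenated, so no edge connects nodes from different graphs. For a list of graphs per sample, flatten the nested list before batching and keep track of which graphs belong to which sample.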

    See this Colab example using a TUDataset: https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb?usp=sharing