Tags: pytorch, networkx, pytorch-geometric

Pytorch Geometric: How do I pass the 2nd or 3rd arguments to from_networkx?


I am trying to use from_networkx() from PyTorch Geometric. My first argument is a networkx Graph object, and I am trying to pass a list of strings naming the node attribute(s) to convert. I get an error saying that I am giving it 2 positional arguments when it takes only 1. How can I make this code work, or is there a workaround?

The first line of the output below is the dictionary of node attributes returned by nx.get_node_attributes(I, 'spin'):

    {(0, 0): 1, (0, 1): 1, (0, 2): -1, (0, 3): 1, (1, 0): 1, (1, 1): 1, (1, 2): 1, (1, 3): 1, (2, 0): 1, (2, 1): -1, (2, 2): -1, (2, 3): 1, (3, 0): -1, (3, 1): -1, (3, 2): 1, (3, 3): -1}
    Graph with 16 nodes and 32 edges
    <class 'networkx.classes.graph.Graph'>
    Traceback (most recent call last):
      File "pytorch_test.py", line 222, in <module>
        print(from_networkx(I, ["spin"]))
    TypeError: from_networkx() takes 1 positional argument but 2 were given
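The error says the installed from_networkx declares only one positional parameter. A quick way to see what your installed version accepts is to inspect its signature. The sketch below uses only the standard library; the helper `accepts_keyword` and the two stand-in functions are hypothetical and only mimic a one-argument signature and one that also declares `group_node_attrs`.

```python
import inspect
from typing import Callable


def accepts_keyword(fn: Callable, name: str) -> bool:
    """Return True if `fn` declares a parameter `name` usable as a keyword,
    or accepts arbitrary keywords via **kwargs."""
    try:
        params = inspect.signature(fn).parameters.values()
    except (TypeError, ValueError):
        # Some callables expose no signature; report "not supported".
        return False
    for p in params:
        if p.name == name and p.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            return True
        if p.kind is inspect.Parameter.VAR_KEYWORD:
            return True
    return False


# Hypothetical stand-ins, NOT the real PyG functions.
def old_from_networkx(G):
    return G


def new_from_networkx(G, group_node_attrs=None, group_edge_attrs=None):
    return G


print(accepts_keyword(old_from_networkx, "group_node_attrs"))  # False
print(accepts_keyword(new_from_networkx, "group_node_attrs"))  # True
```

With PyTorch Geometric installed, you would pass `torch_geometric.utils.from_networkx` itself as `fn`; `torch_geometric.__version__` also reports the installed version directly.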

Solution

  • I suspect you are running PyTorch Geometric 1.7.2 or earlier. In those versions, from_networkx takes only one parameter, the graph itself; the optional group_node_attrs and group_edge_attrs parameters exist only in newer releases.

    However, even before these parameters were introduced (and this is still the default behaviour), from_networkx copied every node attribute onto the returned Data object, converting it to a tensor where possible:

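    import networkx as nx
    from torch_geometric.utils import from_networkx
    
    # Zachary's karate club: every node carries a 'club' attribute.
    g = nx.karate_club_graph()
    print(g.nodes(data=True))
    # [(0, {'club': 'Mr. Hi'}), (1, {'club': 'Mr. Hi'}), (2, {'club': 'Mr. Hi'}), ....
    
    # Called with the graph only: each node attribute is copied onto the
    # returned Data object under its own name (here data.club).
    data = from_networkx(g)
    
    print(data)
    # Data(club=[34], edge_index=[2, 156])
    # edge_index holds each of the 78 undirected edges in both directions.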
    

    So in your case, call from_networkx(I) with the graph as the only argument; the spin values will then be available as data.spin.
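To see why data.spin appears without passing any attribute names, here is a minimal pure-Python sketch of the collection step that the default conversion performs. The helper `collect_node_attrs` is hypothetical: it only gathers values per attribute key in node order, while the real function then tries `torch.tensor(...)` on each list and keeps the plain list if that fails (for example with strings). The input reuses the first four nodes of the question's `spin` dictionary.

```python
from typing import Any, Dict, Hashable, List, Tuple

# Stand-in for G.nodes(data=True): (node, attribute-dict) pairs.
NodeList = List[Tuple[Hashable, Dict[str, Any]]]


def collect_node_attrs(nodes: NodeList) -> Dict[str, List[Any]]:
    """Group node attribute values by key, one entry per node, in node order.

    Like from_networkx, this expects every node to carry the same keys;
    unlike the real function, it does not check that.
    """
    out: Dict[str, List[Any]] = {}
    for _, attrs in nodes:
        for key, value in attrs.items():
            out.setdefault(str(key), []).append(value)
    return out


nodes = [((0, 0), {"spin": 1}), ((0, 1), {"spin": 1}),
         ((0, 2), {"spin": -1}), ((0, 3), {"spin": 1})]
collected = collect_node_attrs(nodes)
print(collected)  # {'spin': [1, 1, -1, 1]}
```

Note that the tuple labels such as (0, 0) do not survive: from_networkx relabels nodes to the integers 0..n-1 in the order of I.nodes(), so data.spin[i] belongs to the i-th node in that order.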