Construct NetworkX graph from Pandas DataFrame


I'd like to create some NetworkX graphs from a simple Pandas DataFrame:

        Loc 1   Loc 2   Loc 3   Loc 4   Loc 5   Loc 6   Loc 7
Foo     0       0       1       1       0       0           0
Bar     0       0       1       1       0       1           1
Baz     0       0       1       0       0       0           0
Bat     0       0       1       0       0       1           0
Quux    1       0       0       0       0       0           0

Where Foo… is the index, and Loc 1 to Loc 7 are the columns. But converting to NumPy matrices or recarrays doesn't seem to work for generating input for nx.Graph(). Is there a standard strategy for achieving this? I'm not averse to reformatting the data in Pandas --> dumping to CSV --> importing to NetworkX, but it seems as if I should be able to generate the edges from the index and the nodes from the values.


Solution

  • NetworkX expects a square adjacency matrix (one row and one column per node, with nonzero entries marking edges). Perhaps* you want to pass it something like this:

    In [11]: df2 = pd.concat([df, df.T]).fillna(0)
    

    Note: It's important that the index and columns are in the same order!

    In [12]: df2 = df2.reindex(df2.columns)
    
    In [13]: df2
    Out[13]: 
           Bar  Bat  Baz  Foo  Loc 1  Loc 2  Loc 3  Loc 4  Loc 5  Loc 6  Loc 7  Quux
    Bar      0    0    0    0      0      0      1      1      0      1      1     0
    Bat      0    0    0    0      0      0      1      0      0      1      0     0
    Baz      0    0    0    0      0      0      1      0      0      0      0     0
    Foo      0    0    0    0      0      0      1      1      0      0      0     0
    Loc 1    0    0    0    0      0      0      0      0      0      0      0     1
    Loc 2    0    0    0    0      0      0      0      0      0      0      0     0
    Loc 3    1    1    1    1      0      0      0      0      0      0      0     0
    Loc 4    1    0    0    1      0      0      0      0      0      0      0     0
    Loc 5    0    0    0    0      0      0      0      0      0      0      0     0
    Loc 6    1    1    0    0      0      0      0      0      0      0      0     0
    Loc 7    1    0    0    0      0      0      0      0      0      0      0     0
    Quux     0    0    0    0      1      0      0      0      0      0      0     0
    
    In [14]: graph = nx.from_numpy_array(df2.values)  # nx.from_numpy_matrix was removed in NetworkX 3.0
    

    This doesn't pass the column/index names to the graph; the nodes are labeled 0 to n-1 instead. If you want the names, you can use relabel_nodes (be wary of duplicate labels, which pandas DataFrames allow but which would collapse into a single node here):

    In [15]: graph = nx.relabel_nodes(graph, dict(enumerate(df2.columns)))  # is there a nicer way than dict(enumerate(...))?
    

    *It's unclear exactly what the columns and index represent for the desired graph.
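
For reference, here is a self-contained sketch of the steps above that can be run as a script. It rebuilds the DataFrame from the question and, as an assumption on my part rather than something from the original answer, uses nx.from_pandas_adjacency, which reads a square DataFrame and keeps the index/column labels as node names, so the separate relabel step is not needed. The variable names (nodes, square) are only illustrative.

```python
import networkx as nx
import pandas as pd

# The DataFrame from the question: rows are Foo..Quux, columns are Loc 1..Loc 7.
df = pd.DataFrame(
    [
        [0, 0, 1, 1, 0, 0, 0],
        [0, 0, 1, 1, 0, 1, 1],
        [0, 0, 1, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 1, 0],
        [1, 0, 0, 0, 0, 0, 0],
    ],
    index=["Foo", "Bar", "Baz", "Bat", "Quux"],
    columns=[f"Loc {i}" for i in range(1, 8)],
)

# Every index label and every column label becomes a node.
nodes = list(df.index) + list(df.columns)

# Stack df on top of its transpose, then force rows and columns into the
# same order so the result is a symmetric, square adjacency matrix.
square = (
    pd.concat([df, df.T])
    .reindex(index=nodes, columns=nodes)
    .fillna(0)
    .astype(int)
)

# Nonzero cells become edges; node names come from the labels.
graph = nx.from_pandas_adjacency(square)

print(sorted(graph.edges()))
```

Rows or columns that are all zero (Loc 2 and Loc 5 here) still end up in the graph as isolated nodes. Whether that is what you want depends on what the index and columns are meant to represent.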