Tags: pandas, dataframe, networkx

Construct a networkx graph from a DataFrame


The data is a square DataFrame of pairwise weights (the same data is shown as df in the Minimal Reproducible Example below).

I need to construct a directed graph where each row of the DataFrame corresponds to a node, and an edge is drawn between two nodes if the weight between them is greater than 0.2. How can this be done?

This is my code:

import matplotlib.pyplot as plt
import networkx as nx

G = nx.DiGraph()

# one node per row of the DataFrame
nodes = [node for node in df.index]
G.add_nodes_from(nodes)

# edges = [edge for edge in ...]  <- this is the part I can't get working

pos = nx.spring_layout(G)
nx.draw_networkx_nodes(G, pos, node_size=100)
nx.draw_networkx_labels(G, pos, font_size=10, font_family='sans-serif')
edges = G.edges()
weights = [G[u][v]['weight'] for u, v in edges]
nx.draw_networkx_edges(G, pos, edgelist=edges, width=weights)
plt.axis('off')
plt.show()

I managed to add the nodes, but I couldn't work out how to add the edges.


Solution

  • Update

    You can also build the graph directly from the matrix with nx.from_pandas_adjacency:

    import numpy as np
    import networkx as nx

    thresh = 0.2
    adj = df.mask(np.eye(len(df), dtype=bool), 0)  # reset diagonal values 1 -> 0 (no self-loops)
    adj[adj <= thresh] = 0  # reset edges with weight <= thresh
    G = nx.from_pandas_adjacency(adj, create_using=nx.DiGraph)
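    A self-contained sketch of the from_pandas_adjacency route, using the df values from the Minimal Reproducible Example and assuming the same threshold of 0.2. It zeroes the diagonal with DataFrame.mask, which avoids writing into the underlying NumPy array (that array may be read-only when pandas copy-on-write is enabled). Every remaining non-zero cell (row -> column) becomes one directed edge with a weight attribute:

```python
import numpy as np
import pandas as pd
import networkx as nx

labels = list("ABCDEF")
# Values copied from the Minimal Reproducible Example
df = pd.DataFrame(
    [[1.000000, 0.090827, 0.423134, 0.047234, 0.000000, 0.332323],
     [0.090827, 0.156203, 0.000000, 0.314762, 0.734024, 0.510780],
     [0.423134, 0.000000, 0.000000, 0.000000, 0.324150, 0.635273],
     [0.047234, 0.314762, 0.000000, 0.322648, 0.095448, 0.124856],
     [0.000000, 0.734024, 0.324150, 0.095448, 0.000000, 0.275080],
     [0.332323, 0.510780, 0.635273, 0.124856, 0.275080, 1.000000]],
    index=labels, columns=labels,
)

thresh = 0.2
# Zero the diagonal so no self-loops are created
adj = df.mask(np.eye(len(df), dtype=bool), 0.0)
# Keep only weights strictly above the threshold; everything else becomes 0
adj = adj.where(adj > thresh, 0.0)

# Non-zero cells become weighted directed edges (row label -> column label)
G = nx.from_pandas_adjacency(adj, create_using=nx.DiGraph)
print(G.number_of_nodes(), G.number_of_edges())
```

    Because the matrix is symmetric, each pair above the threshold shows up twice in the DiGraph (A -> C and C -> A), which is why the edge count matches the 16 rows of the stacked edge list.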
    

    Otherwise, you have to stack your DataFrame first, turning it into a long edge list with one (source, target, weight) row per cell:

    thresh = 0.2
    
    edges = (df.stack().rename_axis(['source', 'target']).rename('weight')
               .reset_index().query('(source != target) & (weight > @thresh)'))
    print(edges)
    
    # Output
       source target    weight
    2       A      C  0.423134
    5       A      F  0.332323
    9       B      D  0.314762
    10      B      E  0.734024
    11      B      F  0.510780
    12      C      A  0.423134
    16      C      E  0.324150
    17      C      F  0.635273
    19      D      B  0.314762
    25      E      B  0.734024
    26      E      C  0.324150
    29      E      F  0.275080
    30      F      A  0.332323
    31      F      B  0.510780
    32      F      C  0.635273
    34      F      E  0.275080
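    To see what each step of that pipeline does, here is the same chain on a small hypothetical 3x3 frame (the labels X, Y, Z and the values are made up for illustration). stack produces one row per cell, and query then drops the self-pairs and the weights at or below the threshold:

```python
import pandas as pd

# Hypothetical 3x3 weight matrix, for illustration only
small = pd.DataFrame(
    [[1.0, 0.5, 0.1],
     [0.5, 1.0, 0.3],
     [0.1, 0.3, 1.0]],
    index=["X", "Y", "Z"], columns=["X", "Y", "Z"],
)
thresh = 0.2

# stack(): one row per (row label, column label) cell, 3 x 3 = 9 rows
cells = (small.stack().rename_axis(["source", "target"]).rename("weight")
              .reset_index())
print(cells)

# Drop self-pairs (X -> X, ...) and weights <= thresh
edges = cells.query("(source != target) & (weight > @thresh)")
print(edges)
```

    Only X <-> Y (0.5) and Y <-> Z (0.3) survive, each in both directions, giving four rows.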
    

    Now you can turn this edge list into a directed graph with nx.from_pandas_edgelist:

    G = nx.from_pandas_edgelist(edges, source='source', target='target',
                                edge_attr='weight', create_using=nx.DiGraph)
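    If you would rather keep the structure of the original attempt (a nodes list, then an edges list), a plain nested loop over the matrix does the same job. This is a sketch assuming the df values from the Minimal Reproducible Example and a threshold of 0.2; the (u, v, {'weight': w}) tuples are a form G.add_edges_from accepts:

```python
import pandas as pd
import networkx as nx

labels = list("ABCDEF")
# Values copied from the Minimal Reproducible Example
df = pd.DataFrame(
    [[1.000000, 0.090827, 0.423134, 0.047234, 0.000000, 0.332323],
     [0.090827, 0.156203, 0.000000, 0.314762, 0.734024, 0.510780],
     [0.423134, 0.000000, 0.000000, 0.000000, 0.324150, 0.635273],
     [0.047234, 0.314762, 0.000000, 0.322648, 0.095448, 0.124856],
     [0.000000, 0.734024, 0.324150, 0.095448, 0.000000, 0.275080],
     [0.332323, 0.510780, 0.635273, 0.124856, 0.275080, 1.000000]],
    index=labels, columns=labels,
)

thresh = 0.2
G = nx.DiGraph()

# One node per row of the DataFrame
nodes = [node for node in df.index]
G.add_nodes_from(nodes)

# One directed edge (row -> column) per off-diagonal weight above the threshold
edges = [
    (u, v, {"weight": df.at[u, v]})
    for u in df.index
    for v in df.columns
    if u != v and df.at[u, v] > thresh
]
G.add_edges_from(edges)
print(G.number_of_edges())
```

    Adding the nodes explicitly also keeps rows that have no edge above the threshold as isolated nodes, which from_pandas_edgelist alone would not create.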
    

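    Your drawing code then works on this graph:

    import matplotlib.pyplot as plt

    pos = nx.spring_layout(G)
    nx.draw_networkx_nodes(G, pos, node_size=100)
    nx.draw_networkx_labels(G, pos, font_size=10, font_family='sans-serif')
    edgelist = G.edges()  # named edgelist so the edges DataFrame above is not overwritten
    weights = [G[u][v]['weight'] for u, v in edgelist]  # line width follows edge weight
    nx.draw_networkx_edges(G, pos, edgelist=edgelist, width=weights)
    plt.axis('off')
    plt.show()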
    

    Output:

    [figure: the resulting directed graph of nodes A to F]

    Minimal Reproducible Example:

    >>> df
              A         B         C         D         E         F
    A  1.000000  0.090827  0.423134  0.047234  0.000000  0.332323
    B  0.090827  0.156203  0.000000  0.314762  0.734024  0.510780
    C  0.423134  0.000000  0.000000  0.000000  0.324150  0.635273
    D  0.047234  0.314762  0.000000  0.322648  0.095448  0.124856
    E  0.000000  0.734024  0.324150  0.095448  0.000000  0.275080
    F  0.332323  0.510780  0.635273  0.124856  0.275080  1.000000
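    To rebuild the df shown above (values copied from the printout) so the snippets in this answer can be run, for example:

```python
import pandas as pd

labels = list("ABCDEF")
# Symmetric weight matrix; rows and columns share the labels A..F
df = pd.DataFrame(
    [[1.000000, 0.090827, 0.423134, 0.047234, 0.000000, 0.332323],
     [0.090827, 0.156203, 0.000000, 0.314762, 0.734024, 0.510780],
     [0.423134, 0.000000, 0.000000, 0.000000, 0.324150, 0.635273],
     [0.047234, 0.314762, 0.000000, 0.322648, 0.095448, 0.124856],
     [0.000000, 0.734024, 0.324150, 0.095448, 0.000000, 0.275080],
     [0.332323, 0.510780, 0.635273, 0.124856, 0.275080, 1.000000]],
    index=labels,
    columns=labels,
)
print(df)
```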