Search code examples
python, matplotlib, networkx

python networkx - how to draw graph with varying edge width


Based on this code snippet I tried to create a graph with varying edge width. I have the following data for a graph representing a 4x5 grid with 20 nodes and only up,down,left,right connections:

import numpy as np
# Edge weights for the 4x5 grid, one per forward edge, grouped the same way
# as the forward edges below (4 horizontal edges of a row, then the 5
# vertical edges leading to the next row).
weights = np.array(
    [
        1.1817, 1.5336, 1.1325, 0.9202,          # row 0, horizontal
        1.5881, 1.7083, 0.4012, 0.5972, 0.4937,  # row 0 -> row 1
        1.1593, 1.2978, 0.0218, 0.1328,          # row 1, horizontal
        1.9135, 1.2934, 0.2250, 0.5520, 1.3033,  # row 1 -> row 2
        0.1133, 1.6854, 1.9010, 1.9293,          # row 2, horizontal
        1.8916, 1.5798, 1.6423, 0.0683, 0.1891,  # row 2 -> row 3
        0.6299, 0.2556, 0.7484, 1.8622,          # row 3, horizontal
    ]
)

# The 31 forward edges (smaller node id first), in the same order as `weights`.
_forward_edges = [
    [0, 1], [1, 2], [2, 3], [3, 4],
    [0, 5], [1, 6], [2, 7], [3, 8], [4, 9],
    [5, 6], [6, 7], [7, 8], [8, 9],
    [5, 10], [6, 11], [7, 12], [8, 13], [9, 14],
    [10, 11], [11, 12], [12, 13], [13, 14],
    [10, 15], [11, 16], [12, 17], [13, 18], [14, 19],
    [15, 16], [16, 17], [17, 18], [18, 19],
]

# Full edge list: every forward edge followed by the same edges with their
# endpoints swapped, giving 62 entries in total.
edge_index = [list(edge) for edge in _forward_edges]
edge_index += [[dst, src] for src, dst in _forward_edges]

The order of the weights is the same as the order of the first 31 edges in edge_index (the remaining 31 entries are the same edges with their endpoints swapped).

I wrote the following code to visualize the nodes and their connections:

from itertools import product
import networkx as nx
import matplotlib.pyplot as plt

# Build an undirected graph whose 20 nodes sit on a 4x5 grid, attach the
# weights to its edges, and draw it with per-edge line widths.
G = nx.Graph()
dx = 4 # spacing
# create nodes
# product(range(4), range(5)) walks the grid row by row, so node ids 0..19
# are assigned in row-major order.
for nidx, (ridx, cidx) in enumerate(product(range(4), range(5))):
    #print(ridx,cidx)
    # y is negated, so later rows get smaller y coordinates
    G.add_node(nidx, pos=(dx*cidx, -dx*ridx) )

# create 31 edges
# zip() stops at the shorter input: weights has 31 entries, so only the
# first 31 entries of edge_index are added.
for gidx, w in zip(edge_index, weights):
    #print(gidx, w)
    G.add_edge(*gidx, weight=w)

    
# node id -> (x, y) position stored on the node above
pos=nx.get_node_attributes(G,'pos')
# edge -> weight formatted with three decimals, used as the edge label text
labels = {k:f"{v:.3f}" for k, v in nx.get_edge_attributes(G, 'weight').items()}
nx.draw(G, pos)
nx.draw_networkx_labels(G, pos=pos, font_color='w')
# NOTE(review): no edgelist is passed here, so the widths are presumably
# matched against the order in which G reports its edges, which need not be
# the order of edge_index -- confirm against the networkx documentation.
nx.draw_networkx_edges(G, pos, width=10*weights)
nx.draw_networkx_edge_labels(G, pos, edge_labels=labels)
plt.show()

The result looks as follows: enter image description here

It is almost the result I want; however, I don't know why the edge (7,8) is so thick even though its weight is relatively small (e.g. compared to edge (2,3)). Conversely, the edge (6,7) is much thinner than the edge (5,6). Is this a bug? Or am I doing something wrong? I double-checked the ordering of the weight array but couldn't find a mistake. Any help is appreciated!


Solution

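  • I am not sure what exactly happens inside nx.draw_networkx_edges(); presumably, when no edgelist is given, the widths are matched against the graph's own edge order (G.edges()), which groups edges by node and therefore need not follow the order of edge_index. However, according to the docs, one can explicitly provide an edgelist argument. Providing the edgelist in the same order as the weights fixed the issue: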

    
    # Build the 4x5 grid graph and draw it so that every edge's line width
    # reflects the weight stored on that edge.
    G = nx.Graph()
    dx = 4 # spacing
    # create nodes: ids 0..19 are assigned row by row; y is negated so later
    # rows get smaller y coordinates
    for nidx, (ridx, cidx) in enumerate(product(range(4), range(5))):
        G.add_node(nidx, pos=(dx*cidx, -dx*ridx))

    # create edges: zip() stops at the shorter input, so only the first
    # len(weights) entries of edge_index are added
    for gidx, w in zip(edge_index, weights):
        G.add_edge(*gidx, weight=w)

    pos = nx.get_node_attributes(G, 'pos')
    labels = {k: f"{v:.3f}" for k, v in nx.get_edge_attributes(G, 'weight').items()}

    #######################################################################
    # Fix the edge order once and derive each width from the weight stored
    # on that very edge, so widths and edges cannot get out of step and the
    # two sequences always have the same length.
    edgelist = list(G.edges())
    widths = [10 * G.edges[u, v]['weight'] for u, v in edgelist]
    # Draw the edges only once, with their final widths: a separate default
    # nx.draw(G, pos) pass would add a default-width line under every edge
    # and hide widths smaller than that default.
    nx.draw(G, pos, edgelist=edgelist, width=widths) # <--- changed this line
    #######################################################################
    nx.draw_networkx_labels(G, pos=pos, font_color='w')
    nx.draw_networkx_edge_labels(G, pos, edge_labels=labels)
    plt.show()
    

    enter image description here