Search code examples
python matplotlib networkx graph-theory edges

MultiDiGraph edges from networkx draw with connectionStyle


Is it possible to somehow draw different edges at the same nodes with different curvatures using connectionstyle?

I wrote the following code, but I got all three edges overlapped:

import networkx as nx
import matplotlib.pyplot as plt

# Minimal reproduction: a directed multigraph with three parallel edges
# from 'n1' to 'n2', distinguished only by their edge keys 0, 1 and 2.
G = nx.MultiDiGraph()
G.add_node('n1')
G.add_node('n2')
G.add_edge('n1', 'n2', 0)
G.add_edge('n1', 'n2', 1)
G.add_edge('n1', 'n2', 2)

pos = nx.spring_layout(G)
# A single connectionstyle string is passed for the whole drawing, so every
# edge is given the same 'rad'; the reported result is that all three edges
# overlap.
nx.draw(G, pos, with_labels=True, connectionstyle='arc3, rad = 0.3')

plt.show()

Solution

  • This can be done by drawing each edge separately with its own rad value, as shown below. Note that this approach uses f-strings, which require Python 3.6 or later; on older versions you will have to build the connectionstyle string another way (for example with str.format).

    Code:

    import networkx as nx
    import matplotlib.pyplot as plt
    
    # Two nodes joined by three parallel edges; each edge stores the
    # curvature it should be drawn with in its 'rad' attribute.
    G = nx.MultiDiGraph()
    G.add_nodes_from(['n1', 'n2'])
    for curvature in (0.1, 0.2, 0.3):
        G.add_edge('n1', 'n2', rad=curvature)

    plt.figure(figsize=(6, 6))

    pos = nx.spring_layout(G)
    nx.draw_networkx_nodes(G, pos)
    nx.draw_networkx_labels(G, pos)

    # Draw the edges one at a time so that each call can carry its own
    # connectionstyle built from that edge's 'rad' value.
    for src, dst, attrs in G.edges(data=True):
        nx.draw_networkx_edges(
            G,
            pos,
            edgelist=[(src, dst)],
            connectionstyle=f'arc3, rad = {attrs["rad"]}',
        )

    plt.show()
    

    Output:

    enter image description here

    We could even create a new function to do this for us:

    import networkx as nx
    import matplotlib.pyplot as plt
    
    def new_add_edge(G, a, b):
        """Add an edge a -> b whose 'rad' is 0.1 larger than any existing a -> b edge.

        The first edge between a pair gets rad=0.1, the next 0.2, and so on, so
        parallel edges can be drawn as a fan of arcs with
        ``connectionstyle=f'arc3, rad = {rad}'``.

        Parameters:
            G: a networkx graph; intended for MultiGraph / MultiDiGraph, where
               parallel edges are kept as separate edges.
            a: source node.
            b: target node.

        Only edges that G itself reports between a and b are considered. In a
        directed graph that means edges running a -> b, so the two directions
        are numbered independently and each starts again at 0.1.
        """
        # get_edge_data returns None when there is no a-b edge (or a node is
        # missing), and respects edge direction for directed graphs.
        edge_data = G.get_edge_data(a, b)
        if edge_data is None:
            attr_dicts = []
        elif G.is_multigraph():
            # Multigraphs map edge key -> attribute dict.
            attr_dicts = edge_data.values()
        else:
            # Simple graphs return the single edge's attribute dict directly.
            attr_dicts = [edge_data]

        # Edges added elsewhere without a 'rad' attribute count as rad 0
        # instead of raising KeyError.
        max_rad = max((attrs.get('rad', 0) for attrs in attr_dicts), default=0)
        G.add_edge(a, b, rad=max_rad + 0.1)
    
    G = nx.MultiDiGraph()
    G.add_nodes_from(['n1', 'n2'])

    # Five parallel edges n1 -> n2 first, then five n2 -> n1; new_add_edge
    # stores a 'rad' attribute on every edge it adds.
    for source, target in (('n1', 'n2'), ('n2', 'n1')):
        for _ in range(5):
            new_add_edge(G, source, target)

    plt.figure(figsize=(6, 6))

    pos = nx.spring_layout(G)
    nx.draw_networkx_nodes(G, pos)
    nx.draw_networkx_labels(G, pos)

    # Draw the edges one at a time so that each call can carry its own
    # connectionstyle built from that edge's 'rad' value.
    for src, dst, attrs in G.edges(data=True):
        nx.draw_networkx_edges(
            G,
            pos,
            edgelist=[(src, dst)],
            connectionstyle=f'arc3, rad = {attrs["rad"]}',
        )

    plt.show()
    

    Output:

    enter image description here