Tags: python, matplotlib, networkx, graph-visualization, digraphs

Graph edge overlay when visualizing a networkx DAG using multipartite_layout


Consider the following snippet defining a DAG and drawing it:

import matplotlib.pyplot as plt
import networkx as nx

g = nx.DiGraph()
g.add_edge(1,2)
g.add_edge(2,3)
g.add_edge(3,4)
g.add_edge(1,4)
for layer, nodes in enumerate(nx.topological_generations(g)):
    for node in nodes:
        g.nodes[node]["layer"] = layer

plt.figure()
pos = nx.multipartite_layout(g, subset_key="layer")
nx.draw_networkx(g, node_size=500, pos=pos, arrowsize=20, font_size=15, node_color='magenta')

It generates the image:

[Image: the drawn DAG, in which the 1 → 4 edge is not visible]

The node positions are the way I want them (topologically ordered, left to right). However, the edge 1 → 4 is not visible in the picture. Because every layer holds a single node, all four nodes end up on the same horizontal line, so the straight 1 → 4 edge coincides with the 1 → 2 → 3 → 4 path and is hidden by it. I would like that edge to bend around so it can be seen. That is, I want to see something like:

[Image: the desired drawing, with the 1 → 4 edge bent so that it is visible]

How can I achieve this?
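Why the edge disappears can be checked by printing the coordinates that `multipartite_layout` assigns. The diagnostic sketch below rebuilds the same graph (using the `nx.DiGraph(edge_list)` constructor instead of repeated `add_edge` calls) and collects the distinct y values. With one node per layer, a single y value is expected, which means the straight 1 → 4 segment runs straight through nodes 2 and 3.

```python
import networkx as nx

# Same DAG and layer assignment as in the question.
g = nx.DiGraph([(1, 2), (2, 3), (3, 4), (1, 4)])
for layer, nodes in enumerate(nx.topological_generations(g)):
    for node in nodes:
        g.nodes[node]["layer"] = layer

pos = nx.multipartite_layout(g, subset_key="layer")

# Print each node's (x, y) position in topological order.
for node in sorted(pos):
    x, y = pos[node]
    print(node, round(float(x), 3), round(float(y), 3))

# Collect the distinct y coordinates; a single value means all nodes are collinear.
ys = {round(float(y), 9) for _, y in pos.values()}
print("distinct y values:", ys)
```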


Solution

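  • You can use the connectionstyle parameter to make networkx draw curved edges (networkx forwards it to matplotlib's FancyArrowPatch, which it uses for the edges of a directed graph such as this one). First, though, split the edges into the ones you want drawn straight and the ones you want curved, then draw each group with its own call.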

    import matplotlib.pyplot as plt
    import networkx as nx
    
    g = nx.DiGraph()
    g.add_edge(1,2)
    g.add_edge(2,3)
    g.add_edge(3,4)
    g.add_edge(1,4)
    for layer, nodes in enumerate(nx.topological_generations(g)):
        for node in nodes:
            g.nodes[node]["layer"] = layer
    
    plt.figure()
    pos = nx.multipartite_layout(g, subset_key="layer")
    
    # Only one edge needs to be curved.
    curved_edges = [(1,4)]
    # All the rest can be straight.
    straight_edges = list(set(g.edges()) - set(curved_edges))
    nx.draw_networkx(g, node_size=500, pos=pos, arrowsize=20, font_size=15, node_color='magenta',
                     edgelist=straight_edges)
    # set connectionstyle to draw curved edge(s)
    nx.draw_networkx(g, node_size=500, pos=pos, arrowsize=20, font_size=15, node_color='magenta',
                     connectionstyle='arc3, rad=0.2', edgelist=curved_edges)
    plt.show()
    

    [Image: graph with a curved edge]
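Each draw_networkx call above redraws the nodes and labels in addition to the edges. A variant sketch, using networkx's separate draw_networkx_nodes, draw_networkx_labels and draw_networkx_edges functions, draws the nodes and labels once and splits only the edges. It also replaces the hard-coded curved_edges list with a hypothetical heuristic (not something networkx provides): curve every edge that skips at least one layer, since in this layout those are the edges that can run behind intermediate nodes. The helper name skips_a_layer is purely illustrative, and with several nodes per layer the heuristic may curve edges that were already visible, so treat it as a starting point.

```python
import matplotlib.pyplot as plt
import networkx as nx

g = nx.DiGraph([(1, 2), (2, 3), (3, 4), (1, 4)])
for layer, nodes in enumerate(nx.topological_generations(g)):
    for node in nodes:
        g.nodes[node]["layer"] = layer
pos = nx.multipartite_layout(g, subset_key="layer")

# Hypothetical heuristic (an assumption, not a networkx feature): an edge that
# skips at least one layer may pass behind nodes in between, so draw it curved.
def skips_a_layer(u, v):
    return g.nodes[v]["layer"] - g.nodes[u]["layer"] > 1

curved_edges = [(u, v) for u, v in g.edges() if skips_a_layer(u, v)]
straight_edges = [(u, v) for u, v in g.edges() if not skips_a_layer(u, v)]

# Pass the same node_size to the edge calls so arrowheads stop at the node border.
node_size = 500

fig, ax = plt.subplots()
# Nodes and labels are drawn once...
nx.draw_networkx_nodes(g, pos, ax=ax, node_size=node_size, node_color="magenta")
nx.draw_networkx_labels(g, pos, ax=ax, font_size=15)
# ...and the two edge groups are drawn in separate calls.
straight_patches = nx.draw_networkx_edges(
    g, pos, ax=ax, edgelist=straight_edges, node_size=node_size, arrowsize=20
)
curved_patches = nx.draw_networkx_edges(
    g, pos, ax=ax, edgelist=curved_edges, node_size=node_size, arrowsize=20,
    connectionstyle="arc3,rad=0.2",
)
plt.show()
```

The rad value in connectionstyle controls how far the arc bows away from the straight line; a negative value bends it to the other side.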