Draw a graph with node size and edge width proportional to number of repetitions


I'm working on a paper that is similar to this. I have to draw a directed graph that looks similar to the one presented in the paper.

I work primarily in Python and have explored various options in matplotlib, seaborn and networkx to draw a similar graph, but could not work it out. In my case the nodes come from a few lists of fixed-length binary strings. The example here uses length 5, but the length can go up to 15, so the number of possible nodes (2**15 = 32768) can become really large!

For example, the lists that represent the paths may look like:

all_possible = [f"{i:>05b}" for i in range(2**5)] # Code to generate all possible states
# Paths taken by the graph
path1 = ['start', '01101', '00001', '11000', '11100', '00010', '00100', 'end']
path2 = ['start', '10001', '00111', '01100', '01110', '10000', '11000', '00101', '11011', '01010', 'end']
path3 = ['start', '11100', '01001', '01100', '00011', '10111', '10000', '00001', 'end']
path4 = ['start', '01100', '01110', '11111', '11011', '10001', '11011', '11101', '00000', 'end']
path5 = ['start', '10001', '11100', '01101', '01001', '01000', '00101', '11001', '00101', '11100', 'end']

Every path starts at the start state and ends at the end state, passing through some combination of the possible states in between. For example, path1 goes start -> 01101 -> 00001 -> 11000 -> 11100 -> 00010 -> 00100 -> end.

As more edges merge into a node, for example node 11100, that node should become larger, or at least change color, to indicate a higher count. The same goes for edges: if an edge is traversed multiple times, for example 01100 -> 01110, it should be drawn wider to reflect the number of traversals.
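To make the counting concrete, here is a minimal sketch of how node visits and edge traversals could be tallied from the example paths. The helper name `count_traversals` is made up for illustration, and it assumes that every appearance of a node in a path counts as one visit.

```python
from collections import Counter

path1 = ['start', '01101', '00001', '11000', '11100', '00010', '00100', 'end']
path2 = ['start', '10001', '00111', '01100', '01110', '10000', '11000', '00101', '11011', '01010', 'end']
path3 = ['start', '11100', '01001', '01100', '00011', '10111', '10000', '00001', 'end']
path4 = ['start', '01100', '01110', '11111', '11011', '10001', '11011', '11101', '00000', 'end']
path5 = ['start', '10001', '11100', '01101', '01001', '01000', '00101', '11001', '00101', '11100', 'end']


def count_traversals(paths):
    """Hypothetical helper: count node visits and edge traversals over all paths."""
    node_counts = Counter()
    edge_counts = Counter()
    for path in paths:
        # every appearance of a node in a path counts as one visit
        node_counts.update(path)
        # consecutive pairs (path[0], path[1]), (path[1], path[2]), ... are the edges
        edge_counts.update(zip(path[:-1], path[1:]))
    return node_counts, edge_counts


node_counts, edge_counts = count_traversals([path1, path2, path3, path4, path5])
print(node_counts['11100'])               # 4
print(edge_counts[('01100', '01110')])    # 2
```

These two counters are the quantities that node size and edge width (or color) would be derived from.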

So far I have tried various solutions, but the closest one uses the networkx module:

import matplotlib.pyplot as plt
import networkx as nx

def draw_complete_graph(all_graph_edges: list, fig_size: tuple = (50, 50), graph_img_name: str = "test_graph"):
    G = nx.MultiDiGraph()
    plt.figure(figsize=fig_size)
    for single_edge in all_graph_edges:
        G.add_edge(
            single_edge[0],
            single_edge[-1],
            weight=all_graph_edges.count(single_edge),
            width=all_graph_edges.count(single_edge),
        )
    width_list = list(nx.get_edge_attributes(G, "width").values())
    nx.draw(
        G,
        with_labels=True,
        node_shape="o",
        width=width_list,
        connectionstyle="arc3, rad = 0.1",
        nodelist=["start", "end"],
    )
    plt.savefig(graph_img_name)
    plt.close()

def get_all_edges(all_paths: list[list]) -> list:
    graph_edges = []
    for single_path in all_paths:
        # pair each node with its successor: (path[0], path[1]), (path[1], path[2]), ...
        graph_edges.extend(zip(single_path[:-1], single_path[1:]))
    return graph_edges

if __name__ == "__main__":
    edges = get_all_edges([path1, path2, path3, path4, path5])
    draw_complete_graph(edges)

But this does not give me the result I am after.

My first question is whether anything close to the graph linked above is even possible with Python 3. If it is, please point out what I should add to get closer to the graph I need.


Solution

  • Here is an example using netgraph. The same can be achieved with just networkx, but then you have to work out the order in which the nodes and edges are stored in the nx.Graph instance, because the networkx draw commands only accept lists, not dictionaries, for properties such as node size, edge width, node color, and edge color.

    Here, I am using a log(x+1) transformation to map node and edge traversals to node size/color and edge width/color. You may need to find a different transformation for your real data. The guiding principle is to map highly skewed distributions such as the traversal counts to flat distributions, in which the unique values that are present in your data are easily separable and don't cluster too much.

    #!/usr/bin/env python
    """
    Reproduce https://www.pnas.org/doi/10.1073/pnas.0305937101#fig2
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import networkx as nx
    
    from matplotlib.colors import LogNorm
    from matplotlib.cm import ScalarMappable
    
    from netgraph import Graph, get_radial_tree_layout # pip install netgraph / conda install -c conda-forge netgraph
    
    if __name__ == '__main__':
    
        g = nx.balanced_tree(3, 3, create_using=nx.DiGraph)
        node_positions = get_radial_tree_layout(list(g.edges))
    
        # invert edges as the networkx constructor creates the tree inside-out
        h = nx.DiGraph([(v2, v1) for (v1, v2) in g.edges])
    
        # compute node and edge traversals
        node_traversals = {node : 0 for node in h}
        edge_traversals = {edge : 0 for edge in h.edges}
        for source in h:
            if source != 0: # node 0 is the center (root) of the tree
                for path in nx.all_simple_paths(h, source, 0):
                    for node in path:
                        node_traversals[node] += 1
                    for edge in zip(path[:-1], path[1:]):
                        edge_traversals[edge] += 1
    
        fig, (ax1, ax2) = plt.subplots(1, 2)
    
        # option 1: using node size and edge width
        node_size  = {node : np.log(count+1) for node, count in node_traversals.items()}
        edge_width = {edge : np.log(count+1) for edge, count in edge_traversals.items()}
    
        Graph(h, node_layout=node_positions, arrows=True,
              node_size=node_size, edge_width=edge_width, ax=ax1)
    
        # option 2: using colors
        node_colormap = ScalarMappable(norm=LogNorm(vmin=1, vmax=np.max(list(node_traversals.values()))), cmap='copper')
        edge_colormap = ScalarMappable(norm=LogNorm(vmin=1, vmax=np.max(list(edge_traversals.values()))), cmap='copper')
    
        node_color = {node : node_colormap.to_rgba(count) for node, count in node_traversals.items()}
        edge_color = {edge : edge_colormap.to_rgba(count) for edge, count in edge_traversals.items()}
        Graph(h, node_layout=node_positions, arrows=True,
              node_color=node_color, edge_color=edge_color, ax=ax2)
    
        plt.show()
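
For the networkx-only route mentioned at the start of this answer, the following is a minimal sketch rather than a tested recipe. It assumes two of the paths from the question as input, a plain `nx.DiGraph` with one edge per distinct transition, and arbitrary scale factors (300 and 2) on top of the same log(x+1) transformation. The key step is building the size and width lists in the same order as the `nodelist` and `edgelist` passed to the draw call.

```python
import math
from collections import Counter

import matplotlib.pyplot as plt
import networkx as nx

# two of the paths from the question (path2 and path4)
paths = [
    ['start', '10001', '00111', '01100', '01110', '10000', '11000', '00101', '11011', '01010', 'end'],
    ['start', '01100', '01110', '11111', '11011', '10001', '11011', '11101', '00000', 'end'],
]

# count node visits and edge traversals
node_counts = Counter(node for path in paths for node in path)
edge_counts = Counter(edge for path in paths for edge in zip(path[:-1], path[1:]))

# one edge per distinct transition; repetitions live in edge_counts
G = nx.DiGraph()
G.add_edges_from(edge_counts)

# networkx expects lists, so fix the order once and build every list from it
nodelist = list(G.nodes)
edgelist = list(G.edges)
node_size = [300 * math.log1p(node_counts[node]) for node in nodelist]  # 300 is an arbitrary scale
width = [2 * math.log1p(edge_counts[edge]) for edge in edgelist]        # 2 is an arbitrary scale

fig, ax = plt.subplots(figsize=(10, 10))
pos = nx.spring_layout(G, seed=0)  # any layout works; the seed keeps it reproducible
nx.draw_networkx(
    G, pos, ax=ax,
    nodelist=nodelist, node_size=node_size,
    edgelist=edgelist, width=width,
    connectionstyle="arc3,rad=0.1",  # curve edges so A -> B and B -> A do not overlap
    font_size=8,
)
fig.savefig("traversal_graph.png")
plt.close(fig)
```

Passing `nodelist` and `edgelist` explicitly avoids relying on the graph's internal storage order. The scale factors and layout will likely need tuning for graphs with many more states.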