
How to visualize a networkx graph to track a population of ids over time?


I am very new to networkx, so apologies if this is very easy. I am trying to visualize the evolution of a population over time using a networkx graph. How would that be accomplished?

The closest thing I have found in their docs, visually, is this multipartite layout: https://networkx.org/documentation/stable/auto_examples/drawing/plot_multipartite_graph.html#sphx-glr-auto-examples-drawing-plot-multipartite-graph-py

The graph from that example looks like this: [image: multipartite layout from the linked example]

Code from their example:

import itertools

import matplotlib.pyplot as plt
import networkx as nx

subset_sizes = [5, 5, 4, 3, 2, 4, 4, 3]
subset_color = [
    "gold",
    "violet",
    "violet",
    "violet",
    "violet",
    "limegreen",
    "limegreen",
    "darkorange",
]


def multilayered_graph(*subset_sizes):
    extents = nx.utils.pairwise(itertools.accumulate((0,) + subset_sizes))
    layers = [range(start, end) for start, end in extents]
    G = nx.Graph()
    for (i, layer) in enumerate(layers):
        G.add_nodes_from(layer, layer=i)
    for layer1, layer2 in nx.utils.pairwise(layers):
        G.add_edges_from(itertools.product(layer1, layer2))
    return G


G = multilayered_graph(*subset_sizes)
color = [subset_color[data["layer"]] for v, data in G.nodes(data=True)]
pos = nx.multipartite_layout(G, subset_key="layer")
plt.figure(figsize=(8, 8))
nx.draw(G, pos, node_color=color, with_labels=False)
plt.axis("equal")
plt.show()

As an example of what I am looking for, assume that in period 1 we have 3 ids:

[1,2,3]

In period 2, we have 4 ids:

[1,2,4,5]

Notice how we have 2 new entrants into the list and one exit.
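To make the entrant/exit idea concrete, here is a minimal sketch using plain set differences on the two example lists. The variable names are only illustrative.

```python
period_1 = [1, 2, 3]
period_2 = [1, 2, 4, 5]

# ids present now but not in the previous period
entrants = sorted(set(period_2) - set(period_1))
# ids present in the previous period but gone now
exits = sorted(set(period_1) - set(period_2))

print(entrants, exits)  # [4, 5] [3]
```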

I am envisioning a multipartite graph like the one described in the docs, except that each node is connected by a single edge only to itself in the next layer, with each layer representing a period in a time series. The number of nodes in each layer, oriented vertically, might increase or decrease depending on whether more nodes are added than removed, and visualizing it this way should give a sense of how the population changes over time.

As a bonus: each time a new node is added in a given period, it would be green (so the 1st layer is all green), and every time a node exits (is not present in layer+1), it would be red. Otherwise it would be an empty circle/normal node (the last layer is all empty circles, since we don't know layer+1).
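The coloring rule above could be sketched roughly as follows. This is only an illustration under a few assumptions: periods are given as a list of id lists, "white" stands in for an empty circle, a node that is both new and about to exit is shown as green (the rule does not say which should win), and the third period in the sample call is made up.

```python
def node_colors(periods):
    """Return {(period_index, node_id): color} for a list of id lists."""
    colors = {}
    last_index = len(periods) - 1
    for i, ids in enumerate(periods):
        previous = set(periods[i - 1]) if i > 0 else set()
        # The last layer has no layer+1, so exits cannot be detected there.
        following = set(periods[i + 1]) if i < last_index else None
        for node_id in ids:
            if node_id not in previous:
                colors[(i, node_id)] = "green"  # new entrant (assumed to win over red)
            elif following is not None and node_id not in following:
                colors[(i, node_id)] = "red"    # absent from the next period
            else:
                colors[(i, node_id)] = "white"  # stand-in for an empty circle
    return colors


# The third period is hypothetical, added so that an exit (id 2) shows up as red.
colors = node_colors([[1, 2, 3], [1, 2, 4, 5], [1, 4, 5]])
```

The resulting dictionary is keyed by (layer, id), so the same id can carry a different color in each layer.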


Solution

  • The multipartite example is handy, but your problem is complicated by only wanting to connect nodes that have the same id. NetworkX requires unique node names, so you have to work around that.

    There are multiple ways of approaching this and this solution is likely not optimal, but the following should accomplish what you want! I've used "P" to designate the time periods and refer to the nodes as groups. Along with the things you likely already have imported, I also used defaultdict from the collections module for generating a dictionary that can be created "on-the-fly" without needing to check if a key exists or not before writing it.

    I'm assuming you have, or can format, the data in a way similar to the period_groups dictionary I've defined. I added one additional period to your example to show how that would work. Is there ever a chance that a group could disappear and reappear in the future? If so, this unfortunately won't work as intended: the nodes on either side of the gap would be joined by a single edge, and the node before the gap would not be colored red.

    from collections import defaultdict

    import matplotlib.pyplot as plt
    import networkx as nx
    
    # Dictionary with period names as keys and the id numbers of groups as values 
    period_groups = {
        'P1':[1, 2, 3],
        'P2':[1, 2, 4, 5, 8],
        'P3':[1, 4, 5, 6, 7]
        }
    
    # Dictionary that will be populated with group id numbers as keys and lists of node names (one per period) as values.
    # Makes it easier to add edges between nodes later.
    group_periods = defaultdict(list)
    
    maxperiod = len(period_groups) # Used later for coloring nodes
    
    G = nx.Graph()
    for i, (period, groups) in enumerate(period_groups.items()): # index, period name, group ids
        for group_number in groups:
            nodename = '{}_{}'.format(period, group_number) # e.g. 'P1_1'
            group_periods[group_number].append(nodename) # Add to new dictionary
            G.add_node(nodename,
                pos = (i, group_number), # Added explicitly, rather than using multipartite layout.
                group = group_number, # For labeling later (if desired)
                color = 'blue', # Default color
                period = i+1 # Matches indexing between periods and integer iterator.
                )
    
    # pairwise takes a sequence (e.g. [1, 2, 3]) and yields consecutive pairs: (1, 2), (2, 3).
    # edges collects those pairs for every group into one flat list.
    # e.g. for group 1, it contributes ('P1_1', 'P2_1') and ('P2_1', 'P3_1').
    # The nested list comprehension is a bit confusing, but this post helps explain it:
    # https://stackoverflow.com/questions/18072759/comprehension-on-a-nested-iterables
    
    edges = [edge for periods in group_periods.values() for edge in nx.utils.pairwise(periods)]
    G.add_edges_from(edges) # Adds the edges created above
    
    
    # The following is the coloring scheme that I chose, but it can easily be changed!
    # The first node in a group is green if the group exists for more than one period.
    # If a group exists for only one period and that is not the last period, its node is gray
    # (unclear if you wanted red or green for this).
    # If a group is included in the last (overall) period, its final node is light gray.
    # If a group ends before the last (overall) period, its final node is red.
    # If a node is in an "intermediate" period of its group, it remains blue as originally assigned.
    
    for group, periods in group_periods.items():
        G.nodes[periods[-1]]['color'] = 'lightgray'
        if len(periods) == 1 and G.nodes[periods[-1]]['period'] < maxperiod:
            G.nodes[periods[0]]['color'] = 'gray'
        elif len(periods) > 1 and G.nodes[periods[-1]]['period'] < maxperiod:
            G.nodes[periods[0]]['color'] = 'green'
            G.nodes[periods[-1]]['color'] = 'red'   
        elif len(periods) > 1:
            G.nodes[periods[0]]['color'] = 'green'
    
    nx.draw(G,
        pos = nx.get_node_attributes(G, 'pos'), # Position is defined by a dictionary.
        node_color = [G.nodes[node]['color'] for node in G.nodes()], # Color is from a list.
        labels = nx.get_node_attributes(G,'group')) # Labeled using original group id.
        
    plt.show()
    

    This produces the following graph: [image: NetworkX graph of populations over time]


    Edit with better code: for some reason, I had never used the NetworkX add_path function before, but it essentially does the same thing as utils.pairwise in a simpler way and avoids the more confusing nested list comprehension.

    With that in mind, I'd replace the two lines that build edges and call G.add_edges_from with the following:

    for periods in group_periods.values(): # Each value is one group's list of node names
        nx.add_path(G, periods)
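
For the case mentioned earlier, where a group disappears and later reappears, one possible workaround is to link a group only across consecutive periods in which it is present, so no edge spans a gap. The following is a rough sketch, not part of the original solution: consecutive_edges is a hypothetical helper, the data is altered so that id 3 returns in P3, and it assumes the dictionary keys are in chronological order and that node names follow the same 'P1_1' scheme used above.

```python
# Hypothetical data: id 3 is present in P1, missing in P2, and back in P3.
period_groups = {
    "P1": [1, 2, 3],
    "P2": [1, 2, 4, 5, 8],
    "P3": [1, 3, 4, 5, 6, 7],
}


def consecutive_edges(period_groups):
    """Build edges only between the same id in two adjacent periods."""
    names = list(period_groups)  # assumes insertion order is chronological
    edges = []
    for earlier, later in zip(names, names[1:]):
        shared = set(period_groups[earlier]) & set(period_groups[later])
        for group in sorted(shared):
            edges.append((f"{earlier}_{group}", f"{later}_{group}"))
    return edges


edges = consecutive_edges(period_groups)
```

The resulting list could then be passed to G.add_edges_from(edges) in place of the add_path loop. The coloring loop would also need adjusting, since it assumes each group occupies one unbroken run of periods.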