Tags: python, performance, matplotlib, animation, networkx

Making a matplotlib animation of a NetworkX graph efficiently


I am trying to animate a graph whose edge widths and colors change over time. My code works, but it is extremely slow, and I imagine there are more efficient implementations.

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import networkx as nx


def minimal_graph(datas, pos):
    frames = len(datas[0])
    fig, axes = plt.subplots(1, 2)
    axes = axes.flatten()
    # Draw the first frame of every dataset.
    for j, dat in enumerate(datas):
        # from_numpy_array replaces from_numpy_matrix, which was removed in NetworkX 3.0
        G = nx.from_numpy_array(dat[0])
        nx.draw(G, pos, ax=axes[j])

    def update(it, data, pos, ax):
        print(it)  # progress indicator
        for i, dat in enumerate(data):
            # This is the problematic line: because I clear the axis,
            # everything has to be drawn from scratch on every frame.
            ax[i].clear()
            G = nx.from_numpy_array(dat[it])
            edges, weights = zip(*nx.get_edge_attributes(G, 'weight').items())
            nx.draw(
                G,
                pos,
                node_color='#24FF00',
                edgelist=edges,
                edge_color=weights,
                width=weights,
                edge_vmin=-5,
                edge_vmax=5,
                ax=ax[i])

    ani = animation.FuncAnimation(
        fig, update, frames=frames, fargs=(datas, pos, axes), interval=100)
    ani.save('temp/graphs.mp4')  # needs ffmpeg and an existing temp/ directory
    plt.close()


# 100 frames of dense random 400 x 400 weight matrices per dataset; practically
# every entry is nonzero, so each frame is a complete graph (plus self-loops).
dataset1 = []
dataset2 = []
for i in range(100):
    arr1 = np.random.rand(400, 400)
    arr2 = np.random.rand(400, 400)
    dataset1.append(arr1)
    dataset2.append(arr2)

datasets = [dataset1, dataset2]
G = nx.from_numpy_array(dataset1[0])
pos = nx.spring_layout(G)

minimal_graph(datasets, pos)

As noted in the code comment, the problem is that every frame redraws the graph from scratch. When animating in matplotlib, I usually create the line artists once and then update them with `line.set_data()`, which I know is a lot faster. I just don't know how to do that for a graph drawn with networkx. I found a related question, but it also calls `ax.clear()` and redraws everything on every frame. Is there a way to keep the artists and update only what changes, instead of redrawing everything on every iteration? In my case the nodes never change (color, position and size stay the same); only the edge widths and colors do.
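The create-once, update-in-place pattern described here can be sketched for graph edges in plain matplotlib by skipping `nx.draw` for the edges: build one `matplotlib.collections.LineCollection` with one segment per edge in a fixed order (so segment `i` always belongs to `edges[i]`), draw the nodes once, and in each frame only call `set_array` (colors through the colormap) and `set_linewidths`. This is a minimal sketch, not a networkx feature; the node count, the random positions and weights, the `RdGy` colormap and the width formula are assumptions for illustration.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # off-screen rendering, enough for saving to a file
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.collections import LineCollection

rng = np.random.default_rng(0)
n_nodes, n_frames = 50, 20  # assumed sizes for the sketch

# Node positions as an (n_nodes, 2) array. With networkx and nodes 0..n-1 this
# could be np.array([pos[n] for n in range(n_nodes)]) for pos = nx.spring_layout(G).
xy = rng.random((n_nodes, 2))

# Fixed edge order: each (row, col) of a sparse upper-triangular adjacency matrix.
# Segment i of the collection always belongs to edges[i].
adjacency = np.triu(rng.random((n_nodes, n_nodes)) < 0.1, k=1)
edges = np.argwhere(adjacency)   # shape (n_edges, 2)
segments = xy[edges]             # shape (n_edges, 2 endpoints, 2 coordinates)

# One weight per frame and edge, clipped to [-5, 5] like edge_vmin/edge_vmax above.
weights = np.clip(5 * rng.standard_normal((n_frames, len(edges))), -5, 5)

fig, ax = plt.subplots()
edge_artist = LineCollection(segments, cmap="RdGy")
edge_artist.set_clim(-5, 5)      # keep the color scale fixed across frames
ax.add_collection(edge_artist)
ax.scatter(xy[:, 0], xy[:, 1], color="#24FF00", zorder=2)  # nodes are drawn only once
ax.autoscale_view()
ax.set_axis_off()


def update(frame):
    w = weights[frame]
    edge_artist.set_array(w)                     # colors via the colormap
    edge_artist.set_linewidths(0.5 + np.abs(w))  # wider edges for larger |weight|
    return (edge_artist,)


ani = FuncAnimation(fig, update, frames=n_frames, interval=100, blit=True)
# ani.save("edges.mp4")  # writing an mp4 requires ffmpeg
```

Because the collection is never rebuilt, each frame only swaps in new color and width arrays; the mapping from segments to edges is simply the row order of `edges`.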


Solution

  • nx.draw does not expose the matplotlib artists used to represent the nodes and edges, so you cannot alter their properties in place. Technically, if you draw the edges separately (with nx.draw_networkx_edges), you do get a collection of artists back, but it is non-trivial to map those artists back to the edges, particularly if self-loops are present.

    If you are open to using other libraries to make the animation, I wrote netgraph some time ago. Crucially for your problem, it exposes all artists in easily indexable forms, so their properties can be altered in place without redrawing everything else. netgraph accepts both full (dense) adjacency matrices and networkx Graph objects as inputs, so it should be simple to feed in your data.

    Below is a simple example visualization. If I run the same script with 400 nodes and 1000 edges, it takes about 30 seconds to complete on my laptop.

    #!/usr/bin/env python
    """
    MWE for animating edges.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    
    from netgraph import Graph # pip install netgraph
    from matplotlib.animation import FuncAnimation
    
    total_nodes = 10
    total_frames = 100
    
    adjacency_matrix = np.random.rand(total_nodes, total_nodes) < 0.2
    weight_matrix = 5 * np.random.randn(total_frames, total_nodes, total_nodes)
    
    # precompute normed weight matrix, such that weights are on the interval [0, 1];
    # weights can then be passed directly to matplotlib colormaps (which expect float on that interval)
    vmin, vmax = -5, 5
    weight_matrix[weight_matrix<vmin] = vmin
    weight_matrix[weight_matrix>vmax] = vmax
    weight_matrix -= vmin
    weight_matrix /= vmax - vmin
    
    cmap = plt.cm.RdGy
    
    plt.ion()
    
    fig, ax = plt.subplots()
    g = Graph(adjacency_matrix, arrows=True, ax=ax)
    
    def update(ii):
        artists = []
        for jj, kk in zip(*np.where(adjacency_matrix)):
            w = weight_matrix[ii, jj, kk]
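            g.edge_artists[(jj, kk)].set_facecolor(cmap(w))
            # w = 0.5 corresponds to a raw weight of 0, so large negative weights produce wide edges, too
            g.edge_artists[(jj, kk)].width = 0.01 * np.abs(w-0.5)
            # _update_path is a private netgraph method (leading underscore) that presumably
            # recomputes the edge outline for the new width; it may change between versions
            g.edge_artists[(jj, kk)]._update_path()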
            artists.append(g.edge_artists[(jj, kk)])
        return artists
    
    animation = FuncAnimation(fig, update, frames=total_frames, interval=100, blit=True)
    animation.save('edge_animation.mp4')
    

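The precomputation in the example above (clip to `[vmin, vmax]`, subtract `vmin`, divide by `vmax - vmin`) can also be written with `np.clip`, or left to matplotlib's `Normalize`, which maps `[vmin, vmax]` onto `[0, 1]`. A minimal sketch checking that the three versions agree; the array shape and random seed are assumptions that mirror the MWE.

```python
import numpy as np
from matplotlib.colors import Normalize

vmin, vmax = -5, 5
rng = np.random.default_rng(1)
raw = 5 * rng.standard_normal((100, 10, 10))  # frames x nodes x nodes, as in the MWE

# In-place version from the example: clip, then shift and scale onto [0, 1].
manual = raw.copy()
manual[manual < vmin] = vmin
manual[manual > vmax] = vmax
manual -= vmin
manual /= vmax - vmin

# The same mapping in one expression.
clipped = (np.clip(raw, vmin, vmax) - vmin) / (vmax - vmin)

# Or let matplotlib do it: with clip=True, out-of-range values map to 0 or 1.
norm = Normalize(vmin=vmin, vmax=vmax, clip=True)
via_norm = np.asarray(norm(raw))
```

Normalizing once up front, as the example does, keeps the per-frame work to a colormap lookup; calling `cmap(norm(w))` inside `update` instead keeps the raw weights untouched, which can be convenient if they also drive other properties.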