
Updating a figure with multiple subplots in a loop, where each subplot contains an image/matrix with a colorbar


I am running a simulation that performs calculations on several matrices, which are updated on every pass through a loop. To watch them evolve over time, I plot them with their respective colorbars inside the loop. Everything worked until I added the colorbars; as soon as I did, the subplots stopped displaying. I tried several options, one of which was to omit plt.clf(). That seems to work for a small number of iterations, but without it the simulation slows down because the figure keeps accumulating images (see len(plt.gca().images)). I am looking for a solution to this problem that does not slow down execution.

import numpy as np
import matplotlib.pyplot as plt
import time

def mysubplots(ax,mat,title):
    im00 = ax.imshow(mat);
    ax.set_title(title);
    cax = plt.colorbar(im00, ax=ax);
    return cax

(n,m) = (200,200)
A = np.random.random((n,m))
B = np.random.random((n,m))
C = np.random.random((n,m))

plt.close()
fig,ax = plt.subplots(2,2, sharex=True, sharey=True,figsize=(5,5))
caxa = mysubplots(ax[0,0], A, 'A')
caxb = mysubplots(ax[0,1], B, 'B')
caxc = mysubplots(ax[1,0], C, 'C')
                      
for i in range(500):
    A = i*np.random.random((n,m))
    B = i*np.random.random((n,m))
    C = i*np.random.random((n,m))
       
    caxa.remove()
    caxb.remove()
    caxc.remove()
    plt.clf()
    
    caxa = mysubplots(ax[0,0], A, 'A')
    caxb = mysubplots(ax[0,1], B, 'B')
    caxc = mysubplots(ax[1,0], C, 'C')
    plt.suptitle(f'Iteration {i}')
        
    fig.canvas.draw()
    fig.canvas.flush_events()
    time.sleep(.1)

PS: I changed the scale of the data so that it increases as the loop progresses.


Solution

  • There are some problems in how your code updates the figure: the sequence plt.clf() --> fig.canvas.draw() only works if you do not have subplots. plt.clf() clears the whole figure, which removes the subplots and their axes, so the ax objects created earlier are no longer part of the figure.

    A more viable approach is to clear the content of each axes itself, using ax.cla().
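
To make the difference between the two calls concrete, here is a minimal sketch that counts the images on one axes and the axes on the figure. The variable names are illustrative and not from the question, and it calls fig.clf() directly, which is what plt.clf() applies to the current figure.

```python
import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots(2, 2)
data = np.random.random((20, 20))

# Repeated imshow calls on the same axes stack images on top of each other.
for _ in range(3):
    ax[0, 0].imshow(data)
n_stacked = len(ax[0, 0].images)

# ax.cla() empties the axes but keeps it attached to the figure.
ax[0, 0].cla()
n_after_cla = len(ax[0, 0].images)
n_axes_after_cla = len(fig.axes)

# fig.clf() detaches every axes from the figure, so later drawing
# on the old ax objects no longer shows up in it.
fig.clf()
n_axes_after_clf = len(fig.axes)

print(n_stacked, n_after_cla, n_axes_after_cla, n_axes_after_clf)
plt.close(fig)
```

The stacked-image count is the same quantity the question inspects with len(plt.gca().images); it grows whenever imshow is called again without clearing the axes first.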

    You can use the following code to do what you want:

    import numpy as np
    %matplotlib notebook
    import matplotlib.pyplot as plt
    import time
    
    
    (n,m) = (200,200)
    A = np.random.random((n,m))
    B = np.random.random((n,m))
    C = np.random.random((n,m))
    
    def update_subplot(ax, data, title, i, cmap='viridis'):
        ax.cla()
        im = ax.imshow(data, cmap=cmap)  # pass the cmap argument through
        if i == 0:
            plt.colorbar(im, ax = ax)
        ax.set_title(title)
        
    fig, ax = plt.subplots(2, 2, sharex=True, sharey=True,figsize=(8,8))
    
    for i in range(500):
        
        A = np.random.random((n,m))
        B = np.random.random((n,m))
        C = np.random.random((n,m))
        
        update_subplot(ax[0,0], A, "A", i)
        update_subplot(ax[0,1], B, "B", i)
        update_subplot(ax[1,0], C, "C", i)
        
        plt.suptitle(f'Iteration {i}')
        fig.canvas.draw()
        time.sleep(.1)
    

    One thing to consider with this approach: the colorbar is not updated, because it is only drawn once, when i == 0. To update the colorbar as well, keep a reference to it and call cbar.remove() to delete it, if it exists, before drawing a new one.

    V2.0: added update of colorbar

    As mentioned above, the colorbar in the first version is not updated, since it is only drawn once. In the following code, the function checks whether a colorbar already exists; if so, it removes it and then draws it again. This ensures that you don't end up with 500 colorbars at the end of the for loop.

    (n,m) = (200,200)
    
    def update_subplot(ax, data, title, cbar, cmap='viridis'):
        ax.cla()
        im = ax.imshow(data, cmap=cmap)  # pass the cmap argument through
        if cbar is None:
            cbar = plt.colorbar(im, ax=ax)
        else:
            cbar.remove()
            cbar = plt.colorbar(im, ax=ax)
        ax.set_title(title)
        return cbar
        
    fig, ax = plt.subplots(2, 2, sharex=True, sharey=True,figsize=(8,8))
    cbar00 = None
    cbar01 = None
    cbar10 = None
    
    for i in range(10):
        
        A = np.random.random((n,m))*i
        B = np.random.random((n,m))*i
        C = np.random.random((n,m))*i
        
        cbar00 = update_subplot(ax[0,0], A, "A", cbar00)
        cbar01 = update_subplot(ax[0,1], B, "B", cbar01)
        cbar10 = update_subplot(ax[1,0], C, "C", cbar10)
        
        plt.suptitle(f'Iteration {i}')
        fig.canvas.draw()
        time.sleep(.1)
    

    Here is an animation of the above code, where I multiplied A, B and C by i so that the values keep increasing:

    Animation

    V3.0: how to get it to work in Spyder

    Instead of the notebook backend, the tk backend can be used for matplotlib. Furthermore, it is necessary to flush events after canvas.draw():

    import numpy as np
    %matplotlib tk
    import matplotlib.pyplot as plt
    import time
    
    (n,m) = (200,200)
    
    def update_subplot(ax, data, title, cbar, cmap='viridis'):
        ax.cla()
        im = ax.imshow(data, cmap=cmap)
        if cbar is None:
            cbar = plt.colorbar(im, ax=ax)
        else:
            # remove the old colorbar before drawing a new one
            cbar.remove()
            cbar = plt.colorbar(im, ax=ax)
        ax.set_title(title)
        return cbar
    
    fig, ax = plt.subplots(2, 2, sharex=True, sharey=True,figsize=(8,8))
    cbar00 = None
    cbar01 = None
    cbar10 = None
    
    for i in range(10):
    
        A = np.random.random((n,m))*i
        B = np.random.random((n,m))*i
        C = np.random.random((n,m))*i
    
        cbar00 = update_subplot(ax[0,0], A, "A", cbar00)
        cbar01 = update_subplot(ax[0,1], B, "B", cbar01)
        cbar10 = update_subplot(ax[1,0], C, "C", cbar10)
    
        plt.suptitle(f'Iteration {i}')
        fig.canvas.draw()
        fig.canvas.flush_events()
        time.sleep(.1)
    

    That still leaves running it from a command window.

    V4.0: running from a command window (disclaimer: I am running it on Linux)

    import numpy as np
    import matplotlib.pyplot as plt
    import time
    
    
    (n,m) = (200,200)
    
    def update_subplot(ax, data, title, cbar, cmap='viridis'):
        ax.cla()
        im = ax.imshow(data, cmap=cmap)  # pass the cmap argument through
        if cbar is None:
            cbar = plt.colorbar(im, ax=ax)
        else:
            cbar.remove()
            cbar = plt.colorbar(im, ax=ax)
        ax.set_title(title)
        return cbar
        
    fig, ax = plt.subplots(2, 2, sharex=True, sharey=True,figsize=(8,8))
    cbar00 = None
    cbar01 = None
    cbar10 = None
    
    for i in range(100):
        print(i)
        A = np.random.random((n,m))*i
        B = np.random.random((n,m))*i
        C = np.random.random((n,m))*i
        
        cbar00 = update_subplot(ax[0,0], A, "A", cbar00)
        cbar01 = update_subplot(ax[0,1], B, "B", cbar01)
        cbar10 = update_subplot(ax[1,0], C, "C", cbar10)
        
        plt.suptitle(f'Iteration {i}')
        plt.pause(0.1)
    

    Here, the fig.canvas methods did not really work for me, and the choice of backend had no effect on whether the images were displayed. plt.pause() did the trick, though.
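
Since the question also asks about execution speed, a possible alternative to clearing and redrawing is to create each image and colorbar once and then update the existing image in place. The sketch below is not one of the versions above: it swaps in AxesImage.set_data() and set_clim() for the cla()/imshow()/colorbar() cycle, and the helper names init_subplot, update_image and run are hypothetical. It also scales the data by i + 1 rather than i, an assumption made here so that the first frame is not all zeros. I have not timed it against the versions above, so treat any speed benefit as something to check on your own setup.

```python
import numpy as np
import matplotlib.pyplot as plt


def init_subplot(ax, data, title):
    """Draw the image and its colorbar once; return the image for later updates."""
    im = ax.imshow(data)
    ax.set_title(title)
    ax.figure.colorbar(im, ax=ax)
    return im


def update_image(im, data):
    """Swap in new data and rescale the color limits.

    The colorbar created for this image follows the new limits,
    so no axes or colorbar has to be removed and recreated.
    """
    im.set_data(data)
    im.set_clim(data.min(), data.max())


def run(n_iter=100, shape=(200, 200), pause=0.1):
    fig, ax = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(8, 8))
    targets = [ax[0, 0], ax[0, 1], ax[1, 0]]
    images = [init_subplot(a, np.random.random(shape), t)
              for a, t in zip(targets, "ABC")]
    for i in range(n_iter):
        for im in images:
            # (i + 1) instead of i: an assumption to avoid an all-zero first frame
            update_image(im, np.random.random(shape) * (i + 1))
        fig.suptitle(f"Iteration {i}")
        plt.pause(pause)  # same display mechanism as V4.0
    return fig, images


if __name__ == "__main__":
    run()
```

With this pattern each axes holds exactly one image for the whole run, so len(ax.images) stays at 1 instead of growing with the iteration count.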