Tags: python, python-3.x, matplotlib, plot, figure

How to plot a list of figures in a single subplot?


I have two lists, one of figures and one of their axes. I need to combine them so that each figure becomes one subplot of a single big figure. How can I do that?

I tried using a for loop, but it didn't work.

Here's what I have tried:

import ruptures as rpt
import matplotlib.pyplot as plt

# make random data with 100 samples and 9 columns 
n_samples, n_dims, sigma = 100, 9, 2
n_bkps = 4
signal, bkps = rpt.pw_constant(n_samples, n_dims, n_bkps, noise_std=sigma)

figs, axs = [], []
for i in range(signal.shape[1]):
    points = signal[:,i]
    # detection of change points 
    algo = rpt.Dynp(model='l2').fit(points)
    result = algo.predict(n_bkps=2)
    fig, ax = rpt.display(points, bkps, result, figsize=(15,3))
    figs.append(fig)
    axs.append(ax)
    plt.show()
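As far as I can tell from the ruptures source, rpt.display() creates its axes with plt.subplots(), and every plt.subplots() call without a num= argument opens a brand-new Figure. That is why the loop above yields nine separate figures instead of one. The matplotlib-only sketch below reproduces this behaviour; random data stands in for rpt.pw_constant() (an assumption so it runs without ruptures).

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

plt.close("all")  # start without any open figures
# stand-in for the ruptures signal: 100 samples, 9 columns
signal = np.random.default_rng(0).normal(size=(100, 9))

figs = []
for i in range(signal.shape[1]):
    # without num=, plt.subplots() creates a new Figure on every call
    fig, ax = plt.subplots(figsize=(15, 3))
    ax.plot(signal[:, i])
    figs.append(fig)

print(len(plt.get_fignums()))  # number of open figures: one per column
```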



Solution

  • I had a look at the source code of ruptures.display(): it accepts **kwargs that are passed on to matplotlib's plt.subplots(). Passing num= with the name of an existing figure therefore redirects the output into that single figure, and with a GridSpec we can position the individual subplots within it:

    import ruptures as rpt
    import matplotlib.pyplot as plt
    
    # random piecewise-constant signal: 100 samples, 9 columns (dimensions)
    n_samples, n_dims, sigma = 100, 9, 2
    n_bkps = 4
    signal, bkps = rpt.pw_constant(n_samples, n_dims, n_bkps, noise_std=sigma)
    
    # number of subplots: one per column of the signal
    n_subpl = signal.shape[1]
    # give the figure a name so that rpt.display() can refer to it later
    fig = plt.figure(num="ruptures_figure", figsize=(8, 15))
    # define a grid of n_subpl rows x 1 column
    gs = fig.add_gridspec(n_subpl, 1)
    
    for i in range(n_subpl):
        points = signal[:, i]
        # detect change points in the current column
        algo = rpt.Dynp(model='l2').fit(points)
        result = algo.predict(n_bkps=2)
        # num="ruptures_figure" is forwarded to matplotlib, so the axes
        # are drawn into the predefined figure instead of a new one
        _, curr_ax = rpt.display(points, bkps, result, num="ruptures_figure")
        # move the current subplot into row i of the grid
        curr_ax[0].set_position(gs[i].get_position(fig))
        curr_ax[0].set_subplotspec(gs[i])
    
    plt.show()
    

    Sample output: [image: a single figure with all nine subplots stacked in one column]
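To see the mechanism without ruptures, the sketch below replaces rpt.display() with a hypothetical display_like() helper. Like rpt.display() (as far as I can tell from its source), it creates its axes via plt.subplots() and forwards **kwargs to it. Random data stands in for rpt.pw_constant(). Because num="ruptures_figure" names an existing figure, plt.figure() (called inside plt.subplots()) returns that figure instead of creating a new one, and each new axes is then moved into its own GridSpec row.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def display_like(points, **kwargs):
    """Hypothetical stand-in for rpt.display(): creates its axes with
    plt.subplots() and forwards **kwargs (e.g. num=...) to matplotlib."""
    fig, ax = plt.subplots(1, sharex=True, **kwargs)
    ax.plot(points)
    return fig, [ax]  # list of axes, mirroring how curr_ax[0] is used above


plt.close("all")  # start without any open figures
# stand-in for the ruptures signal: 100 samples, 9 columns
signal = np.random.default_rng(0).normal(size=(100, 9))
n_subpl = signal.shape[1]

fig = plt.figure(num="ruptures_figure", figsize=(8, 15))
gs = fig.add_gridspec(n_subpl, 1)  # n_subpl rows x 1 column

for i in range(n_subpl):
    # num="ruptures_figure" makes plt.subplots() reuse the existing figure
    _, curr_ax = display_like(signal[:, i], num="ruptures_figure")
    # move the new axes from the default 1x1 slot into row i of the grid
    curr_ax[0].set_position(gs[i].get_position(fig))
    curr_ax[0].set_subplotspec(gs[i])

print(len(plt.get_fignums()), len(fig.axes))  # → 1 9
```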