
Embedding several inset axes in another axis using matplotlib


Is it possible to embed a variable number of plots inside a single matplotlib axes? For example, the inset_axes method places inset axes inside a parent axes:

[Image: example of inset axes placed inside a parent axes]

However, I have several rows of plots, and I want to include several inset axes inside the last axes of each row.

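import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots(2, 4, figsize=(15, 15))
for i in range(2):
    # one plot in each of the first three columns (indices 0-2)
    ax[i][0].plot(np.random.random(40))
    ax[i][1].plot(np.random.random(40))
    ax[i][2].plot(np.random.random(40))

    # number of inset axes (components) wanted in the last column
    number_inset = 5
    for j in range(number_inset):
        # the last column is index 3; here all components overlap in one axes
        ax[i][3].plot(np.random.random(40))

plt.show()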

[Image: output of the code above, with all component curves overlapping in the last column]

Instead of the 5 curves drawn on top of each other in the last column, I want several inset axes, each containing one plot. Something like this:

[Image: sketch of the desired layout, with several stacked inset plots in the last column of each row]
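One way to approximate this layout with the inset_axes method mentioned above is sketched below. It assumes Matplotlib 3.0 or later, where Axes.inset_axes takes bounds [x0, y0, width, height] in the parent's axes-fraction coordinates (0 to 1). add_stacked_insets and its pad argument are hypothetical names for this illustration, not part of Matplotlib; the helper hides the last axes of each row and stacks n insets inside it.

```python
import numpy as np
import matplotlib.pyplot as plt


def add_stacked_insets(parent, n, pad=0.02):
    """Fill *parent* with *n* inset axes stacked top to bottom.

    Hypothetical helper. Bounds given to Axes.inset_axes are
    [x0, y0, width, height] in the parent's axes-fraction coordinates,
    and *pad* (the gap between insets) uses the same units.
    """
    height = (1 - pad * (n - 1)) / n
    insets = []
    for k in range(n):
        # k = 0 is the top inset; each next one sits height + pad lower
        y0 = 1 - (k + 1) * height - k * pad
        insets.append(parent.inset_axes([0, y0, 1, height]))
    # Hide the parent's frame and ticks; the insets are still drawn
    parent.set_axis_off()
    return insets


number_inset = 5
fig, ax = plt.subplots(2, 4, figsize=(15, 8))
insets_per_row = []
for i in range(2):
    # ordinary plots in the first three columns
    for j in range(3):
        ax[i][j].plot(np.random.random(40))
    # one inset per component in the last column
    insets = add_stacked_insets(ax[i][3], number_inset)
    for inset in insets:
        inset.plot(np.random.random(40))
        inset.set(xticklabels=[], yticklabels=[])
    insets_per_row.append(insets)
```

Call plt.show() afterwards to display the figure. Because the insets live inside the parent axes, they follow its size automatically; the trade-off is that the gaps between them are set by hand through pad rather than by a layout engine.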

The reason is that every row refers to a different item to be plotted, and the last column is supposed to contain the components of that item. Is there a way to do this in matplotlib, or maybe an alternative way to visualize this?

Thanks


Solution

  • As @hitzg mentioned, the most common way to accomplish something like this is to use GridSpec. GridSpec creates a virtual grid that you can slice to produce subplots, which makes it easy to align fairly complex layouts that follow a regular grid.

    However, it may not be immediately obvious how to use it in this case. You'll need to create a GridSpec with numrows * numinsets rows by numcols columns, then create each "main" axes by slicing out a block of numinsets consecutive grid rows, so that every row of main axes lines up with numinsets inset rows in the last column.

    In the example below (2 rows, 4 columns, 3 insets), the GridSpec has 6 rows, so we'd slice gs[:3, 0] to get the upper-left "main" axes, gs[3:, 0] to get the lower-left "main" axes, gs[:3, 1] to get the next axes in the upper row, and so on. Each inset occupies a single grid row in the last column, gs[m, -1], for m from 0 to 5.

    As a complete example:

    import numpy as np
    import matplotlib.pyplot as plt
    
    def build_axes_with_insets(numrows, numcols, numinsets, **kwargs):
        """
        Makes a *numrows* x *numcols* grid of subplots with *numinsets* subplots
        embedded as "sub-rows" in the last column of each row.
    
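        Returns a figure object and a *numrows* x *numcols* object ndarray where
        all but the last column consist of axes objects, and each entry of the
        last column is a length-*numinsets* object ndarray of axes objects.
        """
        fig = plt.figure(**kwargs)
        # Every "main" row is split into numinsets grid rows.
        gs = plt.GridSpec(numrows*numinsets, numcols)
    
        axes = np.empty([numrows, numcols], dtype=object)
        for i in range(numrows):
            # Add "main" axes, each spanning numinsets grid rows...
            for j in range(numcols - 1):
                axes[i, j] = fig.add_subplot(gs[i*numinsets:(i+1)*numinsets, j])
    
            # Add inset axes, one grid row each, in the last column.
            # np.empty(..., dtype=object) fills every cell with None, so the
            # last cell needs its own array before it can be indexed.
            axes[i, -1] = np.empty(numinsets, dtype=object)
            for k in range(numinsets):
                m = k + i * numinsets
                axes[i, -1][k] = fig.add_subplot(gs[m, -1])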
    
        return fig, axes
    
    def plot(axes):
        """Recursive plotting function just to put something on each axes."""
        for ax in axes.flat:
            data = np.random.normal(0, 1, 100).cumsum()
            try:
                ax.plot(data)
                ax.set(xticklabels=[], yticklabels=[])
            except AttributeError:
                plot(ax)
    
    fig, axes = build_axes_with_insets(2, 4, 3, figsize=(12, 6))
    plot(axes)
    fig.tight_layout()
    plt.show()
    

    [Image: resulting figure, 2 rows of 3 main axes, each row with 3 stacked inset axes in the last column]
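If the manual index arithmetic (i*numinsets:(i+1)*numinsets) feels error-prone, a nested grid is an alternative. The sketch below assumes Matplotlib 3.0 or later for Figure.add_gridspec and SubplotSpec.subgridspec: the outer grid has one row per item, and only the last cell of each row is split into numinsets stacked rows. build_axes_with_nested_insets is a hypothetical name; it returns the same nested ndarray layout as build_axes_with_insets, so the recursive plot() helper in the answer should work on its result unchanged.

```python
import numpy as np
import matplotlib.pyplot as plt


def build_axes_with_nested_insets(numrows, numcols, numinsets, **kwargs):
    """Same return layout as build_axes_with_insets, via nested GridSpecs.

    Each cell of the outer numrows x numcols grid is one SubplotSpec; the
    last cell of each row is split again into numinsets stacked rows.
    """
    fig = plt.figure(**kwargs)
    outer = fig.add_gridspec(numrows, numcols)

    axes = np.empty([numrows, numcols], dtype=object)
    for i in range(numrows):
        # "main" axes: one outer cell each
        for j in range(numcols - 1):
            axes[i, j] = fig.add_subplot(outer[i, j])

        # sub-grid inside the last outer cell: numinsets rows, 1 column
        inner = outer[i, -1].subgridspec(numinsets, 1, hspace=0.1)
        row_insets = np.empty(numinsets, dtype=object)
        for k in range(numinsets):
            row_insets[k] = fig.add_subplot(inner[k, 0])
        axes[i, -1] = row_insets

    return fig, axes


fig, axes = build_axes_with_nested_insets(2, 4, 3, figsize=(12, 6))
for i in range(axes.shape[0]):
    for j in range(axes.shape[1] - 1):
        axes[i, j].plot(np.random.normal(0, 1, 100).cumsum())
    for inset in axes[i, -1]:
        inset.plot(np.random.normal(0, 1, 100).cumsum())
        inset.set(xticklabels=[], yticklabels=[])
```

A side effect of nesting is that the spacing between insets (hspace=0.1 here, an arbitrary choice) is independent of the spacing between main rows. For overall spacing, constrained layout (layout="constrained", Matplotlib 3.5 or later, passed through **kwargs) is documented to handle nested GridSpecs and may be worth trying instead of tight_layout.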