python matplotlib figure subplot

Reshape axes in figure using matplotlib


I am using a method from a library that uses matplotlib to generate figures.

I receive an array of axes:

[<matplotlib.axes._axes.Axes at 0x117a32a90>,
 <matplotlib.axes._axes.Axes at 0x117bb1d68>,
 <matplotlib.axes._axes.Axes at 0x10bae8390>,
 <matplotlib.axes._axes.Axes at 0x10bb0add8>,
 <matplotlib.axes._axes.Axes at 0x10c153898>,
 <matplotlib.axes._axes.Axes at 0x1159412e8>,
 <matplotlib.axes._axes.Axes at 0x115964d30>]

In the original figure, all axes are in the same row (see the first figure, and imagine two additional axes on the right side). I would like to reshape the figure (à la numpy) into a grid of axes (see the second figure).

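[First figure: the original layout, with all axes in a single row]

[Second figure: the desired layout, with the axes arranged in a grid]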

Is it possible?
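In numpy terms, reshaping a flat sequence of 7 axes into a 3 x 3 grid means sending flat index k to row k // 3 and column k % 3 (row-major, or C, order), with the last two cells left empty. A small sketch of that mapping with np.unravel_index; the 3 x 3 target shape is an assumption matching the grid in the second figure:

```python
import numpy as np

n_axes = 7            # number of axes returned by the library above
nrows, ncols = 3, 3   # assumed target grid shape

# Row-major ("C order") mapping from flat index k to (row, col),
# the same order numpy uses for reshape.
rows, cols = np.unravel_index(np.arange(n_axes), (nrows, ncols))
cells = list(zip(rows.tolist(), cols.tolist()))

for k, (r, c) in enumerate(cells):
    print(f"axes {k} -> row {r}, col {c}")
```

Each cell is the same pair that divmod(k, ncols) returns, so either form can be used to decide where each axes should go.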

Update - What I tried

Following this answer, I tried to use GridSpec:

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

fig = plt.figure()

axs = ...  # get list of axes (returned by the library)

gs = gridspec.GridSpec(3,3)
for i in range(3):
    for j in range(3):
        k = i+j*3
        if k < len(axs):
            axs[k].set_position(gs[k].get_position(fig))    
            fig.add_subplot(gs[k])

But it does not work, and I don't have a complete grasp of GridSpec yet. The figure displays the right number of subplots, but they are empty: my existing axes do not appear in them.
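A quick way to see what goes wrong is a condensed re-run of the attempt, with plt.subplots(1, 7) standing in for the library call (an assumption). plt.figure() creates a second, empty figure, and fig.add_subplot(gs[k]) adds a new, empty Axes to it instead of moving an existing one. The set_position calls do move the original axes, but inside the figure they already belong to:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, so the check runs without a display
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

# Stand-in for the library call (assumption): 7 axes in a single row.
lib_fig, ax_arr = plt.subplots(1, 7)
axs = list(ax_arr)

# Condensed version of the attempt above.
fig = plt.figure()                 # a second, empty figure
gs = gridspec.GridSpec(3, 3)
for k, ax in enumerate(axs):
    ax.set_position(gs[k].get_position(fig))  # moves ax, but inside lib_fig
    fig.add_subplot(gs[k])                    # adds a NEW, empty Axes to fig

same_figure = axs[0].figure is fig                     # the originals are not in fig
n_new = len(fig.axes)                                  # fig holds only new axes
originals_in_fig = any(ax in fig.axes for ax in axs)   # none of the originals
# The original axes were repositioned within lib_fig.
moved = np.allclose(axs[0].get_position().bounds,
                    gs[0].get_position(lib_fig).bounds)
print(same_figure, n_new, originals_in_fig, moved)
```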


Solution

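  • I think you are almost there. Two things change compared with your attempt: use the figure the axes already belong to (here plt.gcf()) instead of creating a new one with plt.figure(), and drop fig.add_subplot(gs[k]), which adds a new, empty Axes at that position rather than moving the existing one. Without knowing what your plotting function is, I just made a dummy one for illustration.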

    import matplotlib.pyplot as plt
    import matplotlib.gridspec as gridspec
    
    
    def dummy_plots():
        """
        Return a 1d array of dummy plots.
        """
        _, ax_arr = plt.subplots(1, 9)
    
        for ax in ax_arr.flat:
            ax.plot([0, 1], [0, 1])
    
        return ax_arr
    
    
    axs = dummy_plots()
    fig = plt.gcf()
    
    gs = gridspec.GridSpec(3, 3)
    # Move each existing Axes onto the k-th cell of the grid;
    # an integer index into a GridSpec counts cells row by row.
    for k, ax in enumerate(axs):
        ax.set_position(gs[k].get_position(fig))
    
    plt.show()
    

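The same idea can be wrapped in a small helper for any grid shape. This is a sketch under a few assumptions: reshape_axes is a hypothetical name, plt.subplots(1, 7) stands in for the library call, and the figure is taken from the axes themselves (ax.figure) rather than relying on plt.gcf() being the right one. Cells are filled row by row; cells beyond the number of axes stay empty.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove to display interactively
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec


def reshape_axes(axs, nrows, ncols):
    """Move existing axes into an nrows x ncols grid, row by row.

    Hypothetical helper: it works on the figure the axes already belong to,
    so no new figure or subplot is created.
    """
    axs = list(axs)
    if len(axs) > nrows * ncols:
        raise ValueError(f"{len(axs)} axes do not fit in a {nrows}x{ncols} grid")
    fig = axs[0].figure  # the figure the library drew into
    gs = gridspec.GridSpec(nrows, ncols, figure=fig)
    for k, ax in enumerate(axs):
        # gs[k] is the k-th cell in row-major order, like numpy's reshape
        ax.set_position(gs[k].get_position(fig))
    return fig


# Stand-in for the library output (assumption): 7 axes in a single row.
_, ax_arr = plt.subplots(1, 7)
for ax in ax_arr:
    ax.plot([0, 1], [0, 1])

fig = reshape_axes(ax_arr, 3, 3)
# With an interactive backend, call plt.show() here to see the 3 x 3 layout.
```

One caveat: set_position changes where the axes are drawn, not the subplot spec they were created with, so a later fig.tight_layout() (or constrained layout) may move them back to the original one-row layout.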