Tags: python, matplotlib, subplot, multiple-axes

Matplotlib external axes ruin subplot layout


I would like to plot a figure with 3 subplots. The middle one has 3 different x-axes, one of which is detached and placed below the subplot. When I use GridSpec for the layout, the plot areas are spaced equally, but the padding between the axis labels of neighbouring subplots differs hugely.

Here is the code to reproduce the figure:

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

fig = plt.figure(figsize=(3.375, 6.5))
gs0 = gridspec.GridSpec(3, 1, figure=fig)

ax0 = fig.add_subplot(gs0[0])
ax0.set_xlabel('x label 0')

ax1 = fig.add_subplot(gs0[1])
ax1.set_xlabel('x label 1a')
secax1 = ax1.twiny()
secax1.xaxis.set_ticks_position('bottom')
secax1.xaxis.set_label_position('bottom')
secax1.spines['bottom'].set_position(('outward', 40))
secax1.set_xlabel('x label 1b')
thax1 = ax1.twiny()
thax1.set_xlabel('x label 1c')

ax2 = fig.add_subplot(gs0[2])
ax2.set_xlabel('x label 2a')
ax2.set_ylabel('y label 2')
secax2 = ax2.twiny()
secax2.set_xlabel('x label 2b')

plt.tight_layout()
plt.savefig('3 subplots same size.png', dpi=300)
plt.show()

I'm looking for a way either to make the spacing between the complete subplots equal, counting everything that belongs to them (the additional axes and their labels), or to shift the subplots manually within the grid. The subplots don't need to keep the same size.

I tried changing height_ratios like this:

gs0 = gridspec.GridSpec(3, 1, figure=fig, height_ratios=[1, 1.5, 1])

but it doesn't change the spacing between the plots.
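For the manual route, one possible sketch (an editorial alternative, not the accepted answer) is to give up tight_layout, add empty spacer rows to the GridSpec and set hspace=0, so that height_ratios sets the gaps directly. The subplots go into rows 0, 2 and 4. The spacer ratios (0.9 and 1.9) and the margins (left=0.2 and so on) are illustrative guesses for this figure size and the default font, and would need re-tuning by eye if the labels change, because nothing here measures the labels.

```python
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

fig = plt.figure(figsize=(3.375, 6.5))
# Rows 1 and 3 stay empty; their height ratios act as the gaps.
# hspace=0 so that only the spacer rows separate the subplots.
# All ratios and margins are guesses, tuned by eye.
gs0 = gridspec.GridSpec(5, 1, figure=fig, hspace=0,
                        height_ratios=[1, 0.9, 1, 1.9, 1],
                        left=0.2, right=0.95, bottom=0.1, top=0.97)

ax0 = fig.add_subplot(gs0[0])
ax0.set_xlabel('x label 0')

ax1 = fig.add_subplot(gs0[2])
ax1.set_xlabel('x label 1a')
secax1 = ax1.twiny()
secax1.xaxis.set_ticks_position('bottom')
secax1.xaxis.set_label_position('bottom')
secax1.spines['bottom'].set_position(('outward', 40))
secax1.set_xlabel('x label 1b')
thax1 = ax1.twiny()
thax1.set_xlabel('x label 1c')

ax2 = fig.add_subplot(gs0[4])
ax2.set_xlabel('x label 2a')
ax2.set_ylabel('y label 2')
secax2 = ax2.twiny()
secax2.set_xlabel('x label 2b')

# No tight_layout() call: the spacing comes only from the ratios above.
fig.savefig('spacer_rows.png', dpi=300)
```

The larger second spacer leaves room for the detached axis under the middle subplot, while the first gap stays small.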


Solution

  • This kind of problem is where constrained_layout does better than tight_layout. tight_layout only allows one margin size between rows, so every gap between subplots becomes as large as the biggest one needs to be. constrained_layout instead keeps a separate upper and lower margin for each row of the gridspec.

    Yes, constrained_layout is marked experimental, but only so that its behaviour can be changed without warning; the API is not likely to change.

    import matplotlib.pyplot as plt
    
    # constrained_layout gives each gridspec row its own top and bottom margin
    fig, (ax0, ax1, ax2) = plt.subplots(3, 1, figsize=(3.375, 6.5),
                                        constrained_layout=True)
    
    ax0.set_xlabel('x label 0')
    
    # middle subplot: primary x-axis, a detached axis 40 pt below it, and a top axis
    ax1.set_xlabel('x label 1a')
    secax1 = ax1.twiny()
    secax1.xaxis.set_ticks_position('bottom')
    secax1.xaxis.set_label_position('bottom')
    secax1.spines['bottom'].set_position(('outward', 40))
    secax1.set_xlabel('x label 1b')
    thax1 = ax1.twiny()
    thax1.set_xlabel('x label 1c')
    
    # bottom subplot: primary x-axis plus a top axis
    ax2.set_xlabel('x label 2a')
    ax2.set_ylabel('y label 2')
    secax2 = ax2.twiny()
    secax2.set_xlabel('x label 2b')
    
    plt.show()

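On newer Matplotlib releases the same layout can be requested with layout="constrained", and the padding between rows can be adjusted through the layout engine. The sketch below assumes Matplotlib 3.6 or later, where Figure.get_layout_engine() and the engine's set() method are available; the h_pad value of 0.1 inch is an arbitrary example, not a recommended setting.

```python
import matplotlib.pyplot as plt

# layout="constrained" is the newer spelling of constrained_layout=True
fig, (ax0, ax1, ax2) = plt.subplots(3, 1, figsize=(3.375, 6.5),
                                    layout="constrained")

ax0.set_xlabel('x label 0')

ax1.set_xlabel('x label 1a')
secax1 = ax1.twiny()
secax1.xaxis.set_ticks_position('bottom')
secax1.xaxis.set_label_position('bottom')
secax1.spines['bottom'].set_position(('outward', 40))
secax1.set_xlabel('x label 1b')
thax1 = ax1.twiny()
thax1.set_xlabel('x label 1c')

ax2.set_xlabel('x label 2a')
ax2.set_ylabel('y label 2')
secax2 = ax2.twiny()
secax2.set_xlabel('x label 2b')

# Extra vertical padding (in inches) around each subplot's decorations;
# 0.1 is an arbitrary example value (assumes Matplotlib >= 3.6).
fig.get_layout_engine().set(h_pad=0.1)

# The layout is computed at draw time, e.g. during savefig.
fig.savefig('constrained_layout.png', dpi=300)
```

Because the margins are measured from the actual tick labels and axis labels of each row, the detached axis under the middle subplot only widens the gap below that row, and the three plot areas keep equal heights.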