Set absolute size of subplots


I know how to set the relative size of subplots within a figure using gridspec or subplots_adjust, and I know how to set the size of a figure using figsize. My problem is setting the absolute size of the subplots.

Use case: I am making two separate plots which will be saved as PDFs for an academic paper. One has two subplots and the other has three (in both cases in a single row). I need all five subplots to be exactly the same size, with exactly the same font sizes (axis labels, tick labels, etc.), in the resulting PDFs. In the example below the fonts are the same size but the subplots are not. If I make the resulting PDFs the same height (and thus the axes too), the font in 3-subplots.pdf is smaller than in 2-subplots.pdf.

MWE:

import matplotlib.pyplot as plt

subplots = [2, 3]
for cols in subplots:

    fig, ax = plt.subplots(1, cols, sharey=True, subplot_kw=dict(box_aspect=1))

    for j in range(cols):
        ax[j].set_title(f'plot {j*cols}')
        ax[j].set_xlabel('My x label')
    ax[0].set_ylabel('My y label')

    plt.tight_layout()
    plt.savefig(f'{cols}-subplots.pdf', bbox_inches='tight', pad_inches=0)
    plt.show()

Output: [image: output]


Solution

  • I ended up solving this by:

    1. setting explicit absolute lengths for subplot width/height, the space between subplots and the space outside subplots,
    2. adding them up to get an absolute figure size,
    3. setting the subplot box_aspect to 1 to keep them square.

    import matplotlib.pyplot as plt
    
    num_subplots = [2, 3]
    
    scale = 1  # scaling factor applied to every length below
    subplot_abs_width = 2 * scale  # width and height of each (square) subplot, in inches
    subplot_abs_spacing_width = 0.2 * scale  # horizontal gap between adjacent subplots
    subplot_abs_excess_width = 0.3 * scale  # total extra width outside the subplots (left and right combined)
    subplot_abs_excess_height = 0.3 * scale  # total extra height outside the subplots (top and bottom combined)
    
    for cols in num_subplots:
        # Figure size = subplots + gaps between them + excess space around them
        fig_width = (cols * subplot_abs_width) + ((cols - 1) * subplot_abs_spacing_width) + subplot_abs_excess_width
        fig_height = subplot_abs_width + subplot_abs_excess_height
    
        fig, ax = plt.subplots(1, cols, sharey=True, figsize=(fig_width, fig_height), subplot_kw=dict(box_aspect=1))
    
        for j in range(cols):
            ax[j].set_title(f'plot {j}')
            ax[j].set_xlabel('My x label')
        ax[0].set_ylabel('My y label')
    
        plt.tight_layout()
        plt.savefig(f'{cols}-subplots.pdf', bbox_inches='tight', pad_inches=0)
        plt.show()
    

    [image: picture of solution]
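
A possible variation on the same idea, as a minimal sketch rather than a definitive recipe: instead of letting tight_layout() choose the margins, the absolute lengths can be converted into the figure-relative fractions that subplots_adjust expects, so every subplot comes out at the requested size in inches. The helper names (fixed_row_layout, make_fixed_size_row) and the margin defaults are hypothetical, chosen only for illustration; the margins are guesses and would need to be large enough for the actual labels and titles.

```python
def fixed_row_layout(cols, ax_size=2.0, gap=0.2,
                     left=0.7, right=0.1, bottom=0.6, top=0.4):
    """Return (figsize, subplots_adjust kwargs) for a row of square axes.

    All lengths are in inches. The margin defaults are illustrative guesses,
    not values from the answer above.
    """
    fig_w = left + cols * ax_size + (cols - 1) * gap + right
    fig_h = bottom + ax_size + top
    adjust = dict(
        left=left / fig_w,          # fractions of the figure width
        right=1 - right / fig_w,
        bottom=bottom / fig_h,      # fractions of the figure height
        top=1 - top / fig_h,
        wspace=gap / ax_size,       # wspace is a fraction of the average axes width
    )
    return (fig_w, fig_h), adjust


def make_fixed_size_row(cols, **lengths):
    """Create a 1 x cols figure whose axes are each ax_size inches square."""
    # Imported here so the arithmetic helper above works without matplotlib.
    import matplotlib.pyplot as plt

    figsize, adjust = fixed_row_layout(cols, **lengths)
    fig, axes = plt.subplots(1, cols, sharey=True, figsize=figsize, squeeze=False)
    fig.subplots_adjust(**adjust)
    return fig, axes[0]
```

With this approach the figure would be saved with a plain fig.savefig(f'{cols}-subplots.pdf'), without tight_layout() or bbox_inches='tight': the former recomputes the margins, and the latter crops the page to a size that depends on the labels. The trade-off is that the margins have to be tuned by hand until nothing is clipped.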