Search code examples
Tags: python, matplotlib

How to reduce unnecessary white spaces in matplotlib subplot2grid?


I'm creating some plots with histograms and a color bar, but I'm struggling with the huge white gaps between the subplots, and I don't know how to reduce them. This is an example taken from more complex code:

import numpy as np
import matplotlib.pyplot as plt

# Two sparse 10x10 count histograms, each with a single count at
# (row 3, column 0) and another at (row 8, column 1).
ihist = np.zeros((10, 10))
ihist[[3, 8], [0, 1]] = 1.0

vhist = np.zeros((10, 10))
vhist[[3, 8], [0, 1]] = 1.0


mins = [1014.3484983803353, -168.92777938399416]
maxs = [5420.578637599565, 1229.7294914536292]
labels = ['x ($\\AA$)', 'y ($\\AA$)']

# Data-space rectangle (left, right, bottom, top) shared by both images.
extent = [mins[0], maxs[0], mins[1], maxs[1]]


def _frame_axes(ax):
    """Label both axes of *ax* and pin its limits to the data range."""
    ax.set_xlabel(labels[0])
    ax.set_xlim(mins[0], maxs[0])
    ax.set_ylabel(labels[1])
    ax.set_ylim(mins[1], maxs[1])


fig = plt.figure()
grid_shape = (9, 2)

# SIAs: left column, bottom eight grid rows.
iax = plt.subplot2grid(grid_shape, (1, 0), rowspan=8)
_frame_axes(iax)
iax.imshow(ihist, origin="lower", extent=extent)

# Vacancies: right column, with the y axis labelled and ticked on the right.
vax = plt.subplot2grid(grid_shape, (1, 1), rowspan=8)
_frame_axes(vax)
vax.yaxis.set_label_position("right")
vax.yaxis.tick_right()
vax_img = vax.imshow(vhist, origin="lower", extent=extent)

# Color bar: a full-width strip in the top grid row, ticks and label on top.
cax = plt.subplot2grid(grid_shape, (0, 0), colspan=2)
cbar = fig.colorbar(vax_img, cax=cax, orientation="horizontal")
cbar.set_label("Counts per ion")
cbar.ax.xaxis.set_ticks_position("top")
cbar.ax.xaxis.set_label_position("top")

plt.tight_layout()
plt.show()

And this is the output:

[Output figure not reproduced here.]

As you can see, there is unnecessary white space between the color bar and the histograms, and between the histograms and the bottom of the figure. I want to remove it.

I think it is caused by calls like `plt.subplot2grid((9, 2), (1, 0), rowspan=8)`, but I did it that way to reduce the vertical size of the color bar.

Note: in the real code, the limits are computed on the fly, so the histograms may or may not have the same height and width.


Solution

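  • This is due to the aspect-ratio setting of `imshow`. By default it uses `aspect="equal"`, so each Axes is shrunk to keep its data units square. With an x range much wider than the y range, this leaves empty bands inside the grid cell the Axes was given. If you use: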

    ...
    # aspect="auto" lets the image stretch to fill its whole grid cell instead
    # of keeping square data units (imshow's default, rcParams["image.aspect"],
    # is "equal" unless changed), which is what left the empty bands.
    iax.imshow(ihist, origin="lower", extent=[mins[0], maxs[0], mins[1], maxs[1]], aspect="auto")
    
    ...
    # Same one-keyword change for the vacancy histogram.
    vax_img = vax.imshow(vhist, origin="lower", extent=[mins[0], maxs[0], mins[1], maxs[1]], aspect="auto")
    

    i.e., add `aspect="auto"` to your `imshow` calls, and it should look as you expect. Note that you may want to switch to `add_gridspec` and `add_subplot`, which give you finer control over the grid. For example, an equivalent layout would be:

    ...
    fig = plt.figure()

    # Two rows x two columns: a thin top row for the colour bar and a tall
    # bottom row for the two histograms. The 1:8 ratio matches the original
    # subplot2grid((9, 2), ...) layout, where the colour bar took one of the
    # nine rows and each histogram spanned the other eight.
    gs = fig.add_gridspec(
        2, 2, height_ratios=(1, 8)
    )

    # SIAs
    iax = fig.add_subplot(gs[1, 0])
    iax.set_xlabel(labels[0])
    iax.set_xlim(mins[0], maxs[0])
    iax.set_ylabel(labels[1])
    iax.set_ylim(mins[1], maxs[1])
    # aspect="auto" lets the image fill its grid cell instead of keeping
    # square data units, which is what produced the empty bands.
    iax.imshow(ihist, origin="lower", extent=[mins[0], maxs[0], mins[1], maxs[1]], aspect="auto")

    # Vacancies (y axis labelled and ticked on the right-hand side)
    vax = fig.add_subplot(gs[1, 1])
    vax.set_xlabel(labels[0])
    vax.set_xlim(mins[0], maxs[0])
    vax.set_ylabel(labels[1])
    vax.set_ylim(mins[1], maxs[1])
    vax.yaxis.set_label_position("right")
    vax.yaxis.tick_right()
    vax_img = vax.imshow(vhist, origin="lower", extent=[mins[0], maxs[0], mins[1], maxs[1]], aspect="auto")

    # Color bar spanning both columns of the top row, ticks and label on top
    cax = fig.add_subplot(gs[0, :])
    cbar = fig.colorbar(vax_img, cax=cax, orientation="horizontal")
    cbar.set_label("Counts per ion")
    cbar.ax.xaxis.set_ticks_position("top")
    cbar.ax.xaxis.set_label_position("top")

    # Use constrained layout instead of tight_layout: it adjusts the spacing
    # so tick labels and axis labels fit (set_layout_engine needs
    # Matplotlib >= 3.6; older versions can pass layout="constrained"
    # to plt.figure()).
    fig.set_layout_engine("constrained")
    plt.show()