Search code examples
pythonmatplotlibcolorbar

matplotlib adjacent subplots: adding colorbar changes size of a subplot


I am trying to plot a figure consisting of 5 x 6 subplots, all of which I want to be adjacent, i.e. share their x and y axes.

I also want to add a colorbar to the rightmost plot in each row, which is normalized for all the subplots in that row.

I add the colorbar using:

   # Create a divider for `ax` so that extra axes can be attached to its sides
   divider = make_axes_locatable(ax)
   # Append a narrow axes on the right of `ax` (5% size, 0.15 padding) to hold the colorbar
   cax = divider.append_axes("right", size="5%", pad=0.15)
   # Draw the colorbar for image `im` into that appended axes
   fig.colorbar(im, cax=cax)

However, adding the colorbar changes the size of the subplot, so it no longer sticks to the other subplots. Here is a sample output, generated with the program below. How can I keep the subplot from changing its dimensions when I add a colorbar?

Here is the example code that generates the image above:

#!/usr/bin/env python3

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors
from mpl_toolkits.axes_grid1 import make_axes_locatable, axes_size 

# Lower and upper axis limits, used for both x and y of every image
lowlim, uplim = 0.43, 0.52
# Number of pixels per side of each test image
nx = 10


# Kernel names; main() draws one subplot column per entry
kernels = [
    'cubic_spline',
    'quintic_spline',
    'wendland_C2',
    'wendland_C4',
    'wendland_C6',
]




#========================
def main():
#========================
    """Plot a grid of adjacent image subplots with one colorbar per row.

    One random (nx, nx) test image is generated for every
    (eta factor, kernel) pair and drawn on a
    len(eta_facts) x len(kernels) grid of subplots with shared axes.
    The colour scale is common to all images of a row, and a colorbar is
    appended to the last subplot of each row with ``make_axes_locatable``.
    The figure is saved as ``for_stackexchange.png``.

    NOTE: the colorbar axes is appended to the last subplot of each row on
    purpose; this is the construction whose effect on the subplot size the
    example is meant to show.
    """

    eta_facts = [0, 1, 2, 3, 4, 5]

    nrows = len(eta_facts)
    ncols = len(kernels)

    Ay_list = [[None for c in range(ncols)] for r in range(nrows)]

    #--------------------------------
    # Loop and compute As
    #--------------------------------
    for row in range(nrows):
        for col in range(ncols):

            # Builtin float instead of np.float: that alias was deprecated
            # in NumPy 1.20 and removed in NumPy 1.24.
            A = np.zeros((nx, nx), dtype=float)

            for i in range(nx):
                for j in range(nx):
                    A[j,i] = row + np.random.random()/10 # not a typo: need A[j,i] for imshow

            Ay_list[row][col] = A

    #------------------------------------
    # Now plot it
    #------------------------------------

    fig = plt.figure(figsize=(3.5*ncols, 3.5*nrows))

    axrows = []
    for r in range(nrows):
        axcols = []
        for c in range(ncols):
            # Share x with the subplot directly above and y with the subplot
            # directly to the left; the first row / first column have no
            # such neighbour.
            shared = {}
            if r > 0:
                shared['sharex'] = axrows[r-1][c]
            if c > 0:
                shared['sharey'] = axcols[c-1]

            # add_subplot indices are 1-based and run row by row
            axcols.append(fig.add_subplot(nrows, ncols, r*ncols + c + 1,
                                          aspect='equal', **shared))
        axrows.append(axcols)

    cmap = 'YlGnBu_r'

    for row, axcols in enumerate(axrows):

        # Common colour limits for all images of this row
        minval = min(np.min(A) for A in Ay_list[row])
        maxval = max(np.max(A) for A in Ay_list[row])

        for col, ax in enumerate(axcols):

            im = ax.imshow(Ay_list[row][col], origin='lower',
                vmin=minval, vmax=maxval, cmap=cmap,
                extent=(lowlim, uplim, lowlim, uplim),
                zorder=1)

            # only plot colorbar for last column
            if col == ncols - 1:
                divider = make_axes_locatable(ax)
                cax = divider.append_axes("right", size="5%", pad=0.15)
                fig.colorbar(im, cax=cax)

            ax.set_xlim((lowlim, uplim))
            ax.set_ylim((lowlim, uplim))

            # cosmetics: ticks and tick labels only on the left edge of the
            # first column and on the bottom edge of the last row
            left = col == 0
            bottom = row == nrows - 1

            ax.tick_params(
                axis='both',        # changes apply to both axes
                which='both',       # both major and minor ticks are affected
                bottom=bottom,      # bottom ticks only in the last row
                top=False,          # ticks along the top edge are off
                left=left,          # left ticks only in the first column
                right=False,        # ticks along the right edge are off
                labelbottom=bottom, # bottom labels only in the last row
                labeltop=False,     # labels along the top edge are off
                labelleft=left,     # left labels only in the first column
                labelright=False)   # labels along the right edge are off

            if row == 0:
                ax.set_title(kernels[col] + ' kernel', fontsize=14)
            if col == 0:
                ax.set_ylabel(r"$\eta = $ "+str(eta_facts[row])+r"$\eta_0$")

    fig.suptitle(r"Some title", fontsize=18)
    # Leave the top 3% of the figure free for the suptitle
    plt.tight_layout(rect=(0, 0, 1, 0.97))
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.savefig('for_stackexchange.png', dpi=150)
    plt.close()

    print('finished.')





# Run the example only when the file is executed as a script
if __name__ == "__main__":
    main()

Solution

  • Indeed, following the comments from user ImportanceOfBeingErnest and using axes_grid1, I could make it work.

    The main difference: For each row, I define an ImageGrid object:

        # One ImageGrid per figure row: each grid occupies one cell of an
        # (nrows, 1) layout and holds ncols image axes with no padding
        # between them, plus colorbar axes on its right edge.
        axrows = []
        for r in range(nrows):
            axcols = ImageGrid(fig, (nrows, 1, r+1),
                        nrows_ncols=(1, ncols),
                        axes_pad = 0.0,
                        share_all = True,
                        label_mode = 'L',
                        cbar_mode = 'edge',
                        cbar_location = 'right',
                        cbar_size = "7%",
                        cbar_pad = "2%")
            axrows.append(axcols)
    
    

    And then, at the right place in the code, I add the colorbar with

    # Draw the colorbar for image `im` into the first colorbar axes of the grid `axcols`
    axcols.cbar_axes[0].colorbar(im)
    

    Here is the full code:

    #!/usr/bin/env python3
    
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.colors
    from mpl_toolkits.axes_grid1 import make_axes_locatable, axes_size 
    from mpl_toolkits.axes_grid1 import ImageGrid
    
    # Lower and upper axis limits, used for both x and y of every image
    lowlim, uplim = 0.43, 0.52
    # Number of pixels per side of each test image
    nx = 10


    # Kernel names; main() draws one image column per entry
    kernels = [
        'cubic_spline',
        'quintic_spline',
        'wendland_C2',
        'wendland_C4',
        'wendland_C6',
    ]
    
    
    
    
    #========================
    def main():
    #========================
        """Plot a grid of adjacent images with one colorbar per row.

        One random (nx, nx) test image is generated for every
        (eta factor, kernel) pair.  Each figure row is an ``ImageGrid`` with
        len(kernels) image axes and colorbar axes on its right edge; the
        colour scale is common to all images of a row.  The figure is saved
        as ``for_stackexchange.png``.
        """

        eta_facts = [0, 1, 2, 3, 4, 5]

        nrows = len(eta_facts)
        ncols = len(kernels)

        Ay_list = [[None for c in range(ncols)] for r in range(nrows)]

        #--------------------------------
        # Loop and compute As
        #--------------------------------
        for row in range(nrows):
            for col in range(ncols):

                # Builtin float instead of np.float: that alias was
                # deprecated in NumPy 1.20 and removed in NumPy 1.24.
                A = np.zeros((nx, nx), dtype=float)

                for i in range(nx):
                    for j in range(nx):
                        A[j,i] = row + np.random.random()/10 # not a typo: need A[j,i] for imshow

                Ay_list[row][col] = A

        #------------------------------------
        # Now plot it
        #------------------------------------

        # One extra inch of width compared to 3.5 inches per column
        fig = plt.figure(figsize=(3.5*ncols+1, 3.5*nrows))

        # One ImageGrid per figure row: each grid occupies one cell of an
        # (nrows, 1) layout and holds ncols image axes with no padding
        # between them, plus colorbar axes on its right edge.
        axrows = []
        for r in range(nrows):
            axcols = ImageGrid(fig, (nrows, 1, r+1),
                        nrows_ncols=(1, ncols),
                        axes_pad = 0.0,
                        share_all = True,
                        label_mode = 'L',
                        cbar_mode = 'edge',
                        cbar_location = 'right',
                        cbar_size = "7%",
                        cbar_pad = "2%")
            axrows.append(axcols)

        cmap = 'YlGnBu_r'

        for row, axcols in enumerate(axrows):

            # Common colour limits for all images of this row
            minval = min(np.min(A) for A in Ay_list[row])
            maxval = max(np.max(A) for A in Ay_list[row])

            for col, ax in enumerate(axcols):

                im = ax.imshow(Ay_list[row][col], origin='lower',
                    vmin=minval, vmax=maxval, cmap=cmap,
                    extent=(lowlim, uplim, lowlim, uplim),
                    zorder=1)

                ax.set_xlim((lowlim, uplim))
                ax.set_ylim((lowlim, uplim))

                # cosmetics: ticks and tick labels only on the left edge of
                # the first column and on the bottom edge of the last row
                left = col == 0
                bottom = row == nrows - 1

                ax.tick_params(
                    axis='both',        # changes apply to both axes
                    which='both',       # both major and minor ticks are affected
                    bottom=bottom,      # bottom ticks only in the last row
                    top=False,          # ticks along the top edge are off
                    left=left,          # left ticks only in the first column
                    right=False,        # ticks along the right edge are off
                    labelbottom=bottom, # bottom labels only in the last row
                    labeltop=False,     # labels along the top edge are off
                    labelleft=left,     # left labels only in the first column
                    labelright=False)   # labels along the right edge are off

                if row == 0:
                    ax.set_title(kernels[col] + ' kernel', fontsize=14)
                if col == 0:
                    ax.set_ylabel(r"$\eta = $ "+str(eta_facts[row])+r"$\eta_0$")

            # All images of the row share vmin/vmax, so the last image drawn
            # is used as the mappable for the row's colorbar.
            axcols.cbar_axes[0].colorbar(im)

        fig.suptitle(r"Some title", fontsize=18)
        # Leave the top 3% of the figure free for the suptitle
        plt.tight_layout(rect=(0, 0, 1, 0.97))
        plt.subplots_adjust(wspace=0.0, hspace=0.0)
        plt.savefig('for_stackexchange.png', dpi=150)
        plt.close()

        print('finished.')
    
    
    
    
    
    # Run the example only when the file is executed as a script
    if __name__ == "__main__":
        main()
    

    Which produces this image:

    results