Search code examples
Tags: python, matplotlib, plot, colorbar

Set equal aspect in plot with colorbar


I need to generate a plot with equal aspect on both axes and a colorbar to the right. I've tried setting aspect='auto', aspect=1, and aspect='equal', without success. See below for examples and the MWE.

Using aspect='auto', the colorbars have the correct height but the plots are distorted:

[Image: aspect='auto' — colorbars of the correct height, but distorted plots]

Using aspect=1 or aspect='equal', the plots are square (equal aspect on both axes) but the colorbars are distorted:

[Image: aspect=1 / aspect='equal' — square plots, but distorted colorbars]

In both cases the colorbars are also positioned too far to the right, for no obvious reason. How can I get a square plot with a colorbar of matching height?


MWE (minimal working example)

import numpy as np
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt

def col_plot(params):
    """Draw a scatter plot and its colorbar into two adjacent GridSpec cells.

    Parameters
    ----------
    params : sequence
        ``(gs, i, data)``:

        - ``gs`` is a GridSpec.
        - ``i`` is the index of the cell that receives the scatter plot; the
          colorbar goes into cell ``i + 1``.
        - ``data`` is an iterable of rows whose first three entries are the
          x, y and colour values.
    """
    gs, i, data = params

    # Transpose rows into columns once. In Python 3 ``zip`` returns an
    # iterator, so it must be materialised before it can be indexed.
    columns = list(zip(*data))
    xarr, yarr, zarr = columns[0], columns[1], columns[2]

    xmin, xmax = min(xarr), max(xarr)
    ymin, ymax = min(yarr), max(yarr)

    # aspect=1 (equivalently 'equal') keeps equal scaling on both axes;
    # aspect='auto' would instead stretch the plot to fill its cell.
    plt.subplot(gs[i], aspect=1)

    plt.xlim(xmin, xmax)
    # The y-limits must come from the y data, not the x data.
    plt.ylim(ymin, ymax)
    plt.xlabel('$x axis$', fontsize=20)
    plt.ylabel('$y axis$', fontsize=20)
    # Scatter plot. Passing the colormap by name avoids ``plt.cm.get_cmap``,
    # which was removed in Matplotlib 3.9.
    SC = plt.scatter(xarr, yarr, marker='o', c=zarr, s=60, lw=0.25,
                     cmap='RdYlBu_r', zorder=3)
    # Colorbar in the neighbouring (narrow) GridSpec cell.
    ax0 = plt.subplot(gs[i + 1])
    cbar = plt.colorbar(SC, cax=ax0)
    cbar.set_label('$col bar$', fontsize=21, labelpad=-2)

# Two random datasets of 50 (x, y, colour) rows, one per panel.
data0 = np.random.uniform(0., 1., size=(50, 3))
data1 = np.random.uniform(0., 1., size=(50, 3))

# Top-level figure and a 4x4 grid whose narrow columns 1 and 3 hold the
# colorbars next to the plots in columns 0 and 2.
fig = plt.figure(figsize=(14, 25))
gs = gridspec.GridSpec(4, 4, width_ratios=[1, 0.05, 1, 0.05])

# Each panel goes into an even column; col_plot puts its colorbar in the
# column immediately to the right.
for column, dataset in zip((0, 2), (data0, data1)):
    col_plot([gs, column, dataset])

# Tighten the spacing and write the PNG.
fig.tight_layout()
fig.savefig('colorbar_aspect.png', dpi=300)

Solution

  • You can use an AxesDivider (created with make_axes_locatable) to do that. I have modified your code slightly so that it uses one.

    import numpy as np
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    
    def col_plot(data):
        """Scatter-plot ``data`` on equal-aspect axes with a matching colorbar.

        A new figure is created. The colorbar axes is appended to the right
        of the plot axes with ``make_axes_locatable``, so it follows the
        plot axes' height.

        Parameters
        ----------
        data : iterable of rows
            Each row's first three entries are the x, y and colour values.
            For demonstration purposes, x is doubled before plotting.

        Returns
        -------
        matplotlib.figure.Figure
            The figure that was drawn, so callers can save or show it.
        """
        # Transpose once. In Python 3 ``zip`` is an iterator and cannot be
        # indexed directly.
        columns = list(zip(*data))
        xarr, yarr, zarr = columns[0], columns[1], columns[2]
        # Demonstration only: stretch x so it spans twice the range of y.
        xarr = [2 * x for x in xarr]

        xmin, xmax = min(xarr), max(xarr)
        ymin, ymax = min(yarr), max(yarr)

        fig = plt.figure()
        ax0 = fig.add_subplot(111, aspect='equal')

        # Configure ax0 explicitly rather than relying on pyplot's
        # "current axes" state.
        ax0.set_xlim(xmin, xmax)
        ax0.set_ylim(ymin, ymax)
        ax0.set_xlabel('$x axis$', fontsize=20)
        ax0.set_ylabel('$y axis$', fontsize=20)
        # Scatter plot. Passing the colormap by name avoids
        # ``plt.cm.get_cmap``, which was removed in Matplotlib 3.9.
        SC = ax0.scatter(xarr, yarr, marker='o', c=zarr, s=60, lw=0.25,
                         cmap='RdYlBu_r', zorder=3)

        # Carve a colorbar axes 5% as wide as ax0 off its right side, with a
        # small pad between them.
        the_divider = make_axes_locatable(ax0)
        color_axis = the_divider.append_axes("right", size="5%", pad=0.1)

        # Colorbar.
        cbar = fig.colorbar(SC, cax=color_axis)
        cbar.set_label('$col bar$', fontsize=21, labelpad=-2)
        return fig
    
    # Generate 20 random (x, y, colour) rows.
    data0 = np.random.uniform(0., 1., size=(20, 3))

    col_plot(data0)
    # Display the figure. Without this, running the snippet as a script
    # exits without showing anything.
    plt.show()
    

    Here is the result. For demonstration purposes, I changed your data so that it spans a range of [0, 2] in the x-direction: [Image: square-aspect scatter plot with a colorbar of matching height]