Search code examples
python matplotlib subplot imshow

Align subplot with colorbar


I am trying to share the x-axis between an imshow that has to be square and a classic line plot:

  1. the imshow has to be square
  2. with a colorbar
  3. the plot below should share the same x-axis (or at least look aligned with the imshow)

I have spent two days on this and am going crazy. Does anyone know how to align them?

Square imshow on top with a colorbar, and below it a plot that should share the same x-axis

The code used to produce the image is below.

def myplot(Nbin=20):
    """Plot a square 2-D histogram (with colorbar) above a 1-D histogram of X.

    ``X`` and ``Y`` are 1000 uniform random samples.  The top two rows of a
    3x2 GridSpec show their ``Nbin x Nbin`` 2-D histogram with ``imshow``.
    The bottom row shows the 1-D histogram of ``X`` drawn as a line through
    the bin centres.

    An invisible colorbar is attached to the lower axes so that it loses the
    same horizontal space as the upper axes does to the real colorbar.
    Because the image keeps its own aspect ratio, the two axes may still end
    up with different widths: they are visually aligned at best, not truly
    shared.

    Parameters
    ----------
    Nbin : int, optional
        Number of bins along each axis (default 20).

    Returns
    -------
    matplotlib.figure.Figure
        The figure that was drawn into.
    """
    X = np.random.rand(1000)
    Y = np.random.rand(1000)
    # histogram2d returns (H, edges of 1st arg, edges of 2nd arg); passing
    # (Y, X) puts X along the horizontal axis of the image drawn below.
    h2, yh2, xh2 = np.histogram2d(Y, X, bins=[Nbin, Nbin])
    h1, xh1 = np.histogram(X, bins=Nbin)

    fig = plt.figure()
    gs = gridspec.GridSpec(3, 2)

    # --- top: 2-D histogram image with its colorbar ----------------------
    ax1 = fig.add_subplot(gs[:-1, :])
    im = ax1.imshow(h2, interpolation='nearest', origin='lower',
                    extent=[xh2[0], xh2[-1], yh2[0], yh2[-1]])
    fig.colorbar(im, ax=ax1)
    ax1.set_xlim(xh1[0], xh1[-1])
    ax1.set_ylim(xh1[0], xh1[-1])
    # Use real booleans: modern matplotlib no longer converts 'on'/'off'
    # strings, and the non-empty string 'off' is truthy, so labels would show.
    ax1.tick_params(axis='x', which='both', bottom=True, top=True,
                    labelbottom=False)

    # --- bottom: 1-D histogram of X as a line through the bin centres ----
    ax2 = fig.add_subplot(gs[-1, :])
    bin_centres = xh1[:-1] + np.diff(xh1) / 2.
    ax2.plot(bin_centres, h1)
    ax2.set_xlim(xh1[0], xh1[-1])
    # Steal the same space from ax2 that the real colorbar took from ax1.
    # The mappable is passed explicitly rather than relying on plt.colorbar()
    # locating the image in another axes via gci().
    cb2 = fig.colorbar(im, ax=ax2)
    ax2.tick_params(axis='x', which='both', bottom=True, top=True,
                    labelbottom=True)

    fig.tight_layout()
    fig.subplots_adjust(hspace=0.05)
    # Hide the placeholder colorbar; its reserved space stays in place.
    cb2.ax.set_visible(False)
    return fig

Solution

  • The easiest way to place the second axes directly below the image is probably to use mpl_toolkits.axes_grid1.make_axes_locatable. It shrinks the image to make room for the newly created subplot, and the same divider can be used to position the colorbar.

    import numpy as np
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    
    Nbin = 20

    X = np.random.rand(1000)
    Y = np.random.rand(1000)
    # histogram2d returns (H, edges of 1st arg, edges of 2nd arg); passing
    # (Y, X) puts X along the horizontal axis of the image.
    h2, yh2, xh2 = np.histogram2d(Y, X, bins=[Nbin, Nbin])
    h1, xh1 = np.histogram(X, bins=Nbin)

    fig = plt.figure()

    # Top: 2-D histogram image. Equal x and y limits plus imshow's default
    # 'equal' aspect keep the image square.
    ax1 = fig.add_subplot(111)
    im = ax1.imshow(h2, interpolation='nearest', origin='lower',
                    extent=[xh2[0], xh2[-1], yh2[0], yh2[-1]])
    ax1.set_xlim(xh1[0], xh1[-1])
    ax1.set_ylim(xh1[0], xh1[-1])
    # Booleans, not 'on'/'off': modern matplotlib treats the non-empty string
    # 'off' as True, so the labels would not be hidden.
    ax1.tick_params(axis='x', which='both', bottom=True, top=True,
                    labelbottom=False)

    # Carve the line plot and the colorbar out of ax1's area so they follow
    # the image's actual (aspect-constrained) size and position.
    divider = make_axes_locatable(ax1)
    # sharex makes ax2 truly share the image's x-axis (pan/zoom stay linked).
    ax2 = divider.append_axes("bottom", size="50%", pad=0.08, sharex=ax1)
    cax = divider.append_axes("right", size="5%", pad=0.08)
    # With an explicit cax, an ``ax`` argument would be ignored, so omit it.
    fig.colorbar(im, cax=cax)

    # Bottom: 1-D histogram of X as a line through the bin centres.
    bin_centres = xh1[:-1] + np.diff(xh1) / 2.
    ax2.plot(bin_centres, h1)
    ax2.set_xlim(xh1[0], xh1[-1])
    ax2.tick_params(axis='x', which='both', bottom=True, top=True,
                    labelbottom=True)

    plt.show()
    

    enter image description here