Search code examples
Tags: python, matplotlib, axes, imshow

Adding bar plots at the margins of imshow, keeping bars aligned to cells


I am trying to add a bar plot on top of an imshow plot, and another one on the right, with the bars aligned to the imshow "cells".

I have tried both the approach used in this example (adding histograms at the margins of a scatter plot) and make_axes_locatable.

The result that I get is shown in the figure. There are two problems that I can't fix:

  1. the actual size of the imshow plot is smaller than the size of the axes I am plotting it in, because I want to keep the matrix aspect ratio, so the image is strictly contained in the axes;
  2. even when this is not a problem (see the top plot), the bars are not aligned with the imshow cells.

[Image: the resulting figure]

This is my code

# from mpl_toolkits.axes_grid1 import make_axes_locatable
# NOTE(review): this snippet uses `plt`, `np`, `matplotlib` and `scipy` without
# importing them; it assumes something like `import matplotlib`,
# `import matplotlib.pyplot as plt`, `import numpy as np` and
# `import scipy.stats` has already run.

plt.style.use('dark_background')

# 25 x 200 matrix of uniform random values in [0, 1).
m = np.random.rand(25, 200)

# definitions for the axes (figure fractions: [left, bottom, width, height])
left, width = 0.1, 0.65
bottom, height = 0.1, 0.65
spacing = 0.005

rect0 = [left, bottom, width, height]  # main image
rect1 = [left, bottom + height + spacing, width, 0.2]  # bar strip directly above rect0
rect2 = [left + width + spacing, bottom, 0.2, height]  # bar strip directly right of rect0

# start with a rectangular Figure
fig = plt.figure(figsize=(20, 8))

# Three independent axes: nothing shares x/y limits with ax0, so the bar axes
# autoscale on their own data rather than on the image's cell edges.
ax0 = plt.axes(rect0)
ax0.tick_params(direction='in', top=True, right=True)
ax1 = plt.axes(rect1)
ax1.tick_params(direction='in', labelbottom=False)
ax2 = plt.axes(rect2)
ax2.tick_params(direction='in', labelleft=False)

# matshow keeps the matrix aspect ratio, so the drawn image can occupy only part
# of rect0 (problem 1 in the question). Values are colour-mapped on a log scale.
ax0.matshow(m, norm=matplotlib.colors.LogNorm())

# divider = make_axes_locatable(ax)
# cax = divider.append_axes('right', size='95%', pad=0)
# One bar per matrix column: entropy of that column (reduction over axis 0).
ax1.bar(np.arange(m.shape[1]), np.apply_along_axis(scipy.stats.entropy, 0, m))

# divider = make_axes_locatable(ax)
# cax = divider.append_axes('bottom', size='95%', pad=0)
# One bar per matrix row: entropy of that row (reduction over axis 1).
# orientation='horizontal' is already what barh draws; passing it is redundant.
ax2.barh(np.arange(m.shape[0]), np.apply_along_axis(scipy.stats.entropy, 1, m), orientation='horizontal')
# Writes to a hard-coded, user-specific path.
plt.savefig('/data/l989o/a/so.png')
# Restore the default style that was replaced by 'dark_background' above.
plt.style.use('default')

EDIT: While trying to add details to the plot, such as axis labels or a colorbar, I saw that the general case can be even more complex. Below is the code for this more general case with the extra plot elements, together with the resulting figure.

One note: I had to invert the bar plot on the right because, with orientation='horizontal', the order of the bars was the opposite of the order of the image rows.

[Image: figure produced by the code below]

import functools

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
# This import was commented out in the original snippet, but make_axes_locatable
# is called below to build the colorbar axes, which raised a NameError.
from mpl_toolkits.axes_grid1 import make_axes_locatable

plt.style.use('dark_background')
# 58 x 226 matrix of uniform random values in [0, 20).
m = np.random.rand(58, 226) * 20
# definitions for the axes (figure fractions: [left, bottom, width, height])
left, width = 0.1, 0.65
bottom, height = 0.1, 0.65
spacing = 0.005

rect0 = [left, bottom, width, height]  # main image
rect1 = [left, bottom + height + spacing, width, 0.2]  # bar strip above rect0
rect2 = [left + width + spacing, bottom, 0.2, height]  # bar strip right of rect0

# start with a rectangular Figure
fig = plt.figure(figsize=(20, 8))

ax0 = plt.axes(rect0)
ax0.tick_params(direction='in', top=True, right=True)
ax1 = plt.axes(rect1)
ax1.tick_params(direction='in', labelbottom=False)
ax2 = plt.axes(rect2)
ax2.tick_params(direction='in', labelleft=False)

t = 10  # threshold applied to m
n = 2   # number of discrete colours; the image below is boolean (m > t)
cmap = matplotlib.colors.LinearSegmentedColormap.from_list(None, plt.cm.Set1(range(0, n)), n)
# NOTE: imshow keeps the matrix aspect ratio by default, so the image may fill
# only part of rect0 (problem 1 in the question).
im = ax0.imshow(m > t, cmap=cmap)
ax0.set_xlabel('image')
ax0.set_ylabel('cluster label')

# Colorbar in a thin axes carved off the left side of ax0.
divider = make_axes_locatable(ax0)
cax = divider.append_axes('left', size='1%', pad=1)
cbar = fig.colorbar(im, ticks=range(n), cax=cax)
# cbar.set_lim(-0.5, n - 0.5)
cbar.ax.tick_params(length=0)
# The boolean image maps to colour values 0 and 1; with n == 2 colour bands,
# 0.25 and 0.75 sit at the centres of the two bands.
cbar.set_ticks([0.25, 0.75])
cbar.set_ticklabels([f'<= {t}', f'> {t}'])
cbar.ax.set_title('# cells')

# divider = make_axes_locatable(ax)
# cax = divider.append_axes('right', size='95%', pad=0)
def sum_treshold(v, threshold):
    """Return the number of entries of ``v`` strictly greater than ``threshold``."""
    return np.sum(v > threshold)
# One bar per image column: how many values in that column exceed t.
ax1.bar(np.arange(m.shape[1]), np.apply_along_axis(functools.partial(sum_treshold, threshold=t), 0, m))
# NOTE: imshow centres column j on x = j, so its cells span [-0.5, N - 0.5];
# limits of [0, N] shift the bars half a cell relative to the image (problem 2).
ax1.set_xlim([0, m.shape[1]])

# divider = make_axes_locatable(ax)
# cax = divider.append_axes('bottom', size='95%', pad=0)
# One bar per image row. Positions are reversed because imshow draws row 0 at
# the top while ax2's y-axis increases upward; as above, [0, rows] limits are
# offset by half a cell from the image's row edges.
ax2.barh(np.arange(m.shape[0])[::-1], np.apply_along_axis(functools.partial(sum_treshold, threshold=t), 1, m), orientation='horizontal')
ax2.set_ylim([0, m.shape[0]])
# Writes to a hard-coded, user-specific path.
plt.savefig('/data/l989o/a/so.png')
plt.style.use('default')

EDIT 2: Here is an example of what the final output should look like. To obtain it, I did a trial-and-error binary search and set hard-coded coordinates (which, of course, work only for my specific data matrix and not in general).

[Image: desired output]


Solution

  • I am not sure if this will get you the exact layout you want, but maybe some bits here will be helpful.

    This answer uses gridspec to define the relative ratios of the subplots and inset_axes with transform to add the colorbar. The answer here by @Marc is a nice, simple example of how to use gridspec, if that part is confusing.

    import matplotlib.pyplot as plt
    # NOTE(review): `%matplotlib inline` below is IPython/Jupyter magic, not
    # Python syntax; drop it when running this as a plain .py script.
    %matplotlib inline
    import matplotlib.gridspec as gridspec
    import matplotlib.colors as mcolors
    import numpy as np
    
    # Random 58 x 226 matrix; below only its shape is used (for bar positions).
    m = np.random.rand(58, 226) * 20
    
    fig = plt.figure(figsize=(20,8), constrained_layout=True)
    # 2 x 3 grid: the first two columns hold the top bar plot (row 0) and the
    # image (row 1); ax3 takes the last column across both rows.
    gs = fig.add_gridspec(2, 3)
    ax1 = fig.add_subplot(gs[0, 0:2])
    ax2 = fig.add_subplot(gs[1, 0:2]) 
    ## can add these if you need to share axes:, sharex = ax1, sharey = ax1)
    ax3 = fig.add_subplot(gs[:, -1])
    
    # Layout demo only: bar heights are just np.arange(...), the image is an
    # unrelated 20 x 10 random array, and ax3 draws m.shape[1] (226) horizontal
    # bars, so bars and image cells are not meant to line up here.
    ax1.bar(np.arange(m.shape[1]), np.arange(m.shape[1]))
    vals = ax2.imshow(np.random.random((20,10)), cmap='rainbow', aspect='auto') 
    ## aspect = 'auto' follows the established gridspec space
    ## default for imshow is an equal aspect ratio
    ax3.barh(np.arange(m.shape[1]), np.arange(m.shape[1]))
    
    cbax2=ax2.inset_axes([1.05,0,0.03,1], transform=ax2.transAxes)
    ## the inset axes inputs are x,y,width,height
    ## the transform "anchors" these relative to the ax2 axis
    ## so here we are saying start at 5% past the ax2 width; start at the bottom of ax2 (y=0);
    ### make the inset axis 3% as wide as the ax2 axis; and make it 100% as tall as the ax2 axis
    cbar2=fig.colorbar(vals, cax=cbax2, format = '%1.2g', orientation='vertical')
    

    [Image: output of the code above]

    Updated Based on Comments: Is this closer to your needed answer?

    import functools

    import matplotlib.colors as mcolors
    import matplotlib.gridspec as gridspec
    import matplotlib.pyplot as plt
    import numpy as np

    # In a Jupyter notebook, also run the IPython magic `%matplotlib inline`.
    # It is not valid syntax in a plain .py script, so it is kept as a comment.


    def sum_treshold(v, threshold):
        """Return the number of entries of ``v`` strictly greater than ``threshold``."""
        return np.sum(v > threshold)


    m = np.random.rand(58, 226) * 20
    t = 10  # threshold applied to m
    n = 2   # number of discrete colours; the image below is boolean (m > t)
    cmap = mcolors.LinearSegmentedColormap.from_list(None, plt.cm.Set1(range(0, n)), n)

    # Per-column counts feed the top bar plot, per-row counts the right one.
    count_above_t = functools.partial(sum_treshold, threshold=t)
    col_counts = np.apply_along_axis(count_above_t, 0, m)
    row_counts = np.apply_along_axis(count_above_t, 1, m)

    fig = plt.figure(figsize=(20, 8), constrained_layout=True)
    gs = fig.add_gridspec(2, 3)
    # Share only the axes that index image cells: the top bars share the image's
    # x (columns), the right bars share the image's y (rows). The count axes stay
    # independent. The original shared y between the count axis and the image rows.
    ax1 = fig.add_subplot(gs[0, 0:2])
    ax2 = fig.add_subplot(gs[1, 0:2], sharex=ax1)
    ax3 = fig.add_subplot(gs[1, -1], sharey=ax2)

    # aspect='auto' makes the image fill its gridspec cell, so it has the same
    # on-screen extent as the bar axes that share its limits. With the default
    # equal aspect the image box can shrink inside the cell and break alignment.
    im = ax2.imshow(m > t, cmap=cmap, aspect='auto')
    ax1.bar(np.arange(m.shape[1]), col_counts)
    # No [::-1] reversal: ax3 shares the image's (top-down) y-axis, so bar i is
    # drawn on image row i.
    ax3.barh(np.arange(m.shape[0]), row_counts)

    # imshow centres cell (row i, col j) on (j, i), so cell edges lie at k - 0.5.
    # Pin the shared limits to those edges after all plotting, so that bar
    # autoscaling cannot widen them. The y-limits keep row 0 at the top.
    ax2.set_xlim(-0.5, m.shape[1] - 0.5)
    ax2.set_ylim(m.shape[0] - 0.5, -0.5)

    # The shared tick labels are already shown on the image axes.
    ax1.tick_params(labelbottom=False)
    ax3.tick_params(labelleft=False)

    cbax2 = ax2.inset_axes([-0.10, 0, 0.03, 1], transform=ax2.transAxes)
    cbar2 = fig.colorbar(im, cax=cbax2, format='%1.2g', orientation='vertical')
    

    [Image: output of the updated code above]