Tags: python, matplotlib, colorbar

Custom Spacing for colors in discrete colorbar


Essentially, I want my discrete, binary colorbar in Python matplotlib/seaborn to have custom spacing, so that one color takes up more of the colorbar than the other.

I am using a seaborn heatmap to plot some binary data. Each row contains p different items that were labeled by my binary classifier. Four of the eleven rows belong to Class1 and the other seven belong to Class0. I would like the colorbar to illustrate that breakdown, so that 4/11 of the colorbar has the Class1 color.
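For the breakdown described above (seven of eleven rows in Class0, four in Class1), the band edges and tick positions follow directly from the class counts. A minimal sketch, where `row_labels` is a hypothetical per-row label array standing in for the real data:

```python
import numpy as np

# Hypothetical per-row labels: 7 rows of Class0 (0), 4 rows of Class1 (1)
row_labels = np.array([0] * 7 + [1] * 4)

# Fraction of rows in each class, Class0 first
counts = np.bincount(row_labels, minlength=2)
fractions = counts / counts.sum()

# Band edges on [0, 1]: Class0 covers [0, 7/11), Class1 covers [7/11, 1]
boundaries = np.array([0.0, fractions[0], 1.0])

# Center each tick label within its color band
tick_positions = (boundaries[:-1] + boundaries[1:]) / 2

print("boundaries:", boundaries.round(3))
print("tick positions:", tick_positions.round(3))
```

Values like these are what one would hand to a normalization that supports uneven bands and to `colorbar.set_ticks`.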

    import seaborn as sns
    from matplotlib.colors import ListedColormap

    # binary_preds: my 2-D array of 0/1 classifier labels (11 rows x p items; not shown)

    # make colormap
    yellow = (249 / 255, 231 / 255, 85 / 255)
    blue = (62 / 255, 11 / 255, 81 / 255)
    color_list = [yellow, blue]
    cmap = ListedColormap(color_list)

    # plot data
    h = sns.heatmap(binary_preds, cmap=cmap, cbar_kws=dict(use_gridspec=False, location="left"))

    # draw white separators between rows
    for i in range(len(binary_preds) + 1):
        h.axhline(i, color='white', lw=5)

    colorbar = h.collections[0].colorbar
    colorbar.set_ticks([.25, .75])
    colorbar.set_ticklabels(['Class0', 'Class1'])

    ## code I would like (hypothetical):
    # colorbar.set_spacing([0.37, 0.63])

[Figures omitted: the colorbar my code generates, and the colorbar I want, with the spacing adjusted manually.]


Solution

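  • The following approach uses a BoundaryNorm with boundaries [0, proportion_class0, 1] and passes spacing="proportional" to the colorbar, so each color's length on the colorbar matches the width of its interval: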

    from matplotlib import pyplot as plt
    from matplotlib.colors import BoundaryNorm, ListedColormap
    import numpy as np
    import seaborn as sns
    
    # make colormap
    yellow = (249 / 255, 231 / 255, 85 / 255)
    blue = (62 / 255, 11 / 255, 81 / 255)
    color_list = [yellow, blue]
    cmap = ListedColormap(color_list)
    # define a boundary norm
    proportion_class0 = 0.67  # proportion for class0
    norm = BoundaryNorm([0, proportion_class0, 1], 2)
    
    binary_preds = np.random.choice([False, True], size=(10, 15), p=[proportion_class0, 1 - proportion_class0])
    ax = sns.heatmap(binary_preds, cmap=cmap, norm=norm,
                     cbar_kws=dict(use_gridspec=False, location="left", spacing="proportional"))
    
    for i in range(len(binary_preds) + 1):
        ax.axhline(i, color='white', lw=5)
    
    colorbar = ax.collections[0].colorbar
    colorbar.set_ticks([proportion_class0 / 2, (1 + proportion_class0) / 2])
    colorbar.set_ticklabels(['Class0', 'Class1'])
    
    plt.show()
    

    [Figure omitted: colorbar spacing changed with BoundaryNorm]
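To see why this setup works without drawing anything, the norm and colormap can be queried directly. A minimal sketch reusing the answer's colors and `proportion_class0`: under `spacing="proportional"` the band lengths are the gaps between `norm.boundaries`, `False` (0) falls in the first band, and `True` (1) sits on the top boundary. In the matplotlib releases I am aware of, a value at the top edge goes to the colormap's "over" slot, which by default reuses the last color, so it should still come out blue; the check below only asserts the final color, not the intermediate index.

```python
import numpy as np
from matplotlib.colors import BoundaryNorm, ListedColormap

# Same colors and split as in the answer above
yellow = (249 / 255, 231 / 255, 85 / 255)
blue = (62 / 255, 11 / 255, 81 / 255)
cmap = ListedColormap([yellow, blue])
proportion_class0 = 0.67
norm = BoundaryNorm([0, proportion_class0, 1], 2)

# With spacing="proportional", each band's length is the width of its interval
band_widths = np.diff(norm.boundaries)

# Colors the heatmap would draw for False (0) and True (1)
color_false = cmap(norm(0.0))
color_true = cmap(norm(1.0))

print("band widths:", band_widths)
print("False ->", np.round(color_false[:3], 3))
print("True  ->", np.round(color_true[:3], 3))
```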