python matplotlib seaborn legend heatmap

How to add a legend to a heatmap


I am using a small variation of this really nice code to plot a heatmap.

import matplotlib
import seaborn as sns
import numpy as np
from matplotlib.colors import ListedColormap

np.random.seed(7)
A = np.random.randint(0,100, size=(20,20))
mask_array = np.zeros((20, 20), dtype=bool)
mask_array[:, :5] = True
# Copy the registered colormap so the global one is not modified
cmap = matplotlib.colormaps["viridis"].copy()


# Set the under color to white
cmap.set_under("white")

# Set the over color to black
cmap.set_over("black")

g = sns.heatmap(A, vmin=10, vmax=90, cmap=cmap, mask=mask_array)
# Set color of masked region
g.set_facecolor('lightgrey')

cbar_ax = g.figure.axes[-1]

for spine in cbar_ax.spines.values():
    spine.set(visible=True)
    
# Overlay only the cells equal to 20, drawn in a single flat color
sns.heatmap(A, cmap=ListedColormap([(1.0000, 0.2716, 0.0000)]),
            mask=(A != 20), cbar=False)

The result looks like:

[Figure: the resulting heatmap]

The squares that had the value 20, and so are now colored with RGB (1.0000, 0.2716, 0.0000), indicate that the experiment was broken. I would like to add a legend that shows a square of that color with the word "broken" next to it. The legend has to sit outside the heatmap so that it does not obscure it. How can I do that?


Solution

  • You can make your own rectangle with the color you want and pass it to the legend method. To move the legend around, use the loc and bbox_to_anchor arguments; see the legend guide for more on these: https://matplotlib.org/stable/tutorials/intermediate/legend_guide.html

    from matplotlib.patches import Rectangle
    
    # ...
    
    rect = Rectangle((0, 0), 0, 0, color=(1.0000, 0.2716, 0.0000))
    g.legend(handles=[rect], labels=['Broken'], loc='lower right', bbox_to_anchor=(1, 1))
    


    Inserted into the code you already have, this becomes:

    import matplotlib
    import seaborn as sns
    import numpy as np
    from matplotlib.colors import ListedColormap
    from matplotlib.pyplot import show
    from matplotlib.patches import Rectangle
    
    np.random.seed(7)
    A = np.random.randint(0,100, size=(20,20))
    mask_array = np.zeros((20, 20), dtype=bool)
    mask_array[:, :5] = True
    # Copy the registered colormap so the global one is not modified
    cmap = matplotlib.colormaps["viridis"].copy()
    
    
    # Set the under color to white
    cmap.set_under("white")
    
    # Set the over color to black
    cmap.set_over("black")
    
    g = sns.heatmap(A, vmin=10, vmax=90, cmap=cmap, mask=mask_array)
    # Set color of masked region
    g.set_facecolor('lightgrey')
    
    cbar_ax = g.figure.axes[-1]
    
    for spine in cbar_ax.spines.values():
        spine.set(visible=True)
    
    # Overlay only the cells equal to 20, drawn in a single flat color
    sns.heatmap(A, cmap=ListedColormap([(1.0000, 0.2716, 0.0000)]),
                mask=(A != 20), cbar=False, ax=g)
    
    rect = Rectangle((0, 0), 0, 0, color=(1.0000, 0.2716, 0.0000))
    g.legend(handles=[rect], labels=['Broken'], loc='lower right', bbox_to_anchor=(1, 1))
    
    show()
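
A possible variant of the legend step: the matplotlib legend guide calls stand-in handles like this "proxy artists", and a `Patch` can play the same role as the zero-sized `Rectangle`. The sketch below uses only matplotlib so that it isolates the legend part. `pcolormesh` stands in for the seaborn heatmap, and `add_outside_legend` is a hypothetical helper name. The `Agg` backend, the black edge color and `frameon=False` are assumptions for illustration and are not part of the answer above.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Patch


def add_outside_legend(ax, color, label):
    """Attach a one-entry legend just above the top-right corner of `ax`.

    Hypothetical helper, not part of matplotlib or seaborn.
    """
    # The Patch is never added to the axes; it only serves as a legend handle.
    handle = Patch(facecolor=color, edgecolor="black", label=label)
    # `loc` names the corner of the legend box that gets anchored, and
    # `bbox_to_anchor` gives the anchor point in axes coordinates:
    # (1, 1) is the top-right corner of the axes.
    return ax.legend(handles=[handle], loc="lower right",
                     bbox_to_anchor=(1, 1), frameon=False)


# Stand-in data and plot (assumption: any Axes with a colorbar will do here)
rng = np.random.default_rng(7)
data = rng.integers(0, 100, size=(20, 20))

fig, ax = plt.subplots()
mesh = ax.pcolormesh(data, cmap="viridis", vmin=10, vmax=90)
fig.colorbar(mesh, ax=ax)

legend = add_outside_legend(ax, (1.0000, 0.2716, 0.0000), "Broken")
```

Since `sns.heatmap` returns a matplotlib Axes, calling `add_outside_legend(g, (1.0000, 0.2716, 0.0000), "Broken")` on the `g` from the code above should work the same way. Changing `loc` and `bbox_to_anchor` moves the box. For example, `loc="upper left"` with an x value greater than 1 would place it to the right of the axes, where it may overlap the colorbar.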