Colour statistically non-significant values in seaborn heatmap with a different colour


I wanted to highlight statistically non-significant correlations in a seaborn heatmap. I knew I could hide them with the following code:

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr

# keep the numeric columns only; pandas 2.x raises on the string column 'method'
planets = sns.load_dataset('planets').select_dtypes('number')
corr = planets.corr()

# get the p value for each Pearson coefficient; pandas puts 1 on the diagonal, so subtract it
pvals = planets.corr(method=lambda x, y: pearsonr(x, y)[1]) - np.eye(*corr.shape)
# set the significance threshold
psig = 0.05

plt.figure(figsize=(6, 6))

sns.heatmap(corr[pvals < psig], annot=True, square=True)
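The comment "subtract 1 on the diagonal" relies on documented pandas behaviour: when DataFrame.corr is given a callable, the result always has 1 on the diagonal, whatever the callable returns, so each column would otherwise look non-significant against itself. Below is a minimal sketch of the same p-value computation wrapped in a hypothetical helper, pearson_pvalues. It assumes pandas 2.x, where corr no longer drops non-numeric columns such as method in the planets data by default (hence select_dtypes), and it uses a small synthetic DataFrame instead of sns.load_dataset so that it runs offline.

```python
import numpy as np
import pandas as pd
from scipy.stats import pearsonr


def pearson_pvalues(df: pd.DataFrame) -> pd.DataFrame:
    """Two-sided Pearson p-value for every pair of numeric columns of df (hypothetical helper)."""
    numeric = df.select_dtypes("number")
    # With a callable method, pandas drops rows where either column is missing
    # and writes 1.0 on the diagonal whatever the callable returns.
    pvals = numeric.corr(method=lambda x, y: pearsonr(x, y)[1])
    # Replace that 1.0 by 0.0 so a column is never flagged as n.s. against itself.
    return pvals - np.eye(len(pvals))


# Synthetic stand-in for the planets data (made-up columns):
# 'b' depends on 'a', 'c' is independent noise, 'label' is a string column.
rng = np.random.default_rng(0)
a = rng.normal(size=200)
demo = pd.DataFrame({
    "a": a,
    "b": 2 * a + rng.normal(scale=0.5, size=200),
    "c": rng.normal(size=200),
    "label": ["x"] * 200,
})
pvals = pearson_pvalues(demo)
print(pvals.round(4))
```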

However, hiding the cells this way leaves white holes in the heatmap. I would like to keep the values and the information and only emphasise the non-significant cells with another colour.
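The white holes come from how the masking works. Indexing a DataFrame with a boolean DataFrame is equivalent to DataFrame.where, so every cell that fails the test becomes NaN, and sns.heatmap masks cells with missing values, leaving the axes background visible. The sketch below uses made-up 3×3 matrices (illustrative numbers, not the planets results) to show the equivalence; the commented line shows seaborn's mask argument, which is an explicit way to hide the same cells.

```python
import pandas as pd

# Toy correlation and p-value matrices (hypothetical numbers, for illustration only).
cols = ["a", "b", "c"]
corr = pd.DataFrame([[1.0, 0.8, 0.1],
                     [0.8, 1.0, -0.2],
                     [0.1, -0.2, 1.0]], index=cols, columns=cols)
pvals = pd.DataFrame([[0.0, 0.001, 0.40],
                      [0.001, 0.0, 0.20],
                      [0.40, 0.20, 0.0]], index=cols, columns=cols)
psig = 0.05

# Boolean-DataFrame indexing is DataFrame.where: failing cells become NaN.
masked = corr[pvals < psig]
print(masked.equals(corr.where(pvals < psig)))  # True
print(masked)

# seaborn does not paint NaN cells, hence the holes. The same picture, requested explicitly:
# sns.heatmap(corr, mask=pvals >= psig, annot=True, square=True)
```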



Solution

  • The way to solve it is (a) to draw a second heatmap on the same axes using the complementary threshold, with the non-significant cells in a single grey, and (b) to add a patch to the legend so that the grey gets a label:

    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches
    import seaborn as sns
    from scipy.stats import pearsonr

    # keep the numeric columns only; pandas 2.x raises on the string column 'method'
    planets = sns.load_dataset('planets').select_dtypes('number')
    corr = planets.corr()

    # get the p value for each Pearson coefficient; pandas puts 1 on the diagonal, so subtract it
    pvals = planets.corr(method=lambda x, y: pearsonr(x, y)[1]) - np.eye(*corr.shape)
    # set the significance threshold
    psig = 0.05

    plt.figure(figsize=(6, 6))

    # significant cells, with the default colour map and colour bar
    sns.heatmap(corr[pvals < psig], annot=True, square=True)

    # second heatmap on the same axes: the non-significant cells in a single grey
    grey = sns.color_palette("Greys", n_colors=1, desat=1)[0]
    sns.heatmap(corr[pvals >= psig], annot=True, square=True, cbar=False,
                cmap=[grey])

    # add a legend entry for the grey colour
    # https://stackoverflow.com/questions/44098362/using-mpatches-patch-for-a-custom-legend
    colors = [grey]
    texts = [f"n.s. (at {psig})"]
    patches = [mpatches.Patch(color=c, label=t) for c, t in zip(colors, texts)]
    plt.legend(handles=patches, bbox_to_anchor=(.85, 1.05), loc='center')

    The same approach also works with several masking conditions, for example different significance levels, each drawn as its own heatmap layer in its own colour.
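The closing remark about multiple conditions and significance levels can be sketched as follows. This is an illustration built on the answer's idea, not the answerer's code: the helper heatmap_by_significance and its levels parameter are hypothetical, the data are synthetic so the snippet runs offline, and the coolwarm colour map pinned to the range -1 to 1 is a choice made here for correlations. Cells below the smallest level keep the colour map and colour bar; every band between successive levels, plus everything at or above the largest level, is drawn as its own heatmap layer in a shade of the Greys palette, with one legend patch per band as in the answer.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend so the sketch also runs headless; drop in a notebook
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import pearsonr


def heatmap_by_significance(corr, pvals, levels=(0.01, 0.05), ax=None):
    """Draw corr with one heatmap layer per significance band (hypothetical helper).

    Cells with p < min(levels) use a colour map and colour bar; every band
    between successive levels, plus p >= max(levels), gets its own grey.
    """
    ax = ax if ax is not None else plt.gca()
    levels = sorted(levels)
    sns.heatmap(corr[pvals < levels[0]], annot=True, square=True,
                cmap="coolwarm", vmin=-1, vmax=1, ax=ax)

    bounds = levels + [np.inf]
    # Greys runs from light to dark, so less significant bands come out darker.
    greys = sns.color_palette("Greys", n_colors=len(levels))
    handles = []
    for lo, hi, colour in zip(bounds[:-1], bounds[1:], greys):
        band = (pvals >= lo) & (pvals < hi)
        if not band.to_numpy().any():
            continue  # empty band: skip the layer and its legend entry
        sns.heatmap(corr[band], annot=True, square=True, cbar=False,
                    cmap=[colour], ax=ax)
        label = f"p >= {lo}" if np.isinf(hi) else f"{lo} <= p < {hi}"
        handles.append(mpatches.Patch(color=colour, label=label))
    if handles:
        ax.legend(handles=handles, bbox_to_anchor=(1.0, 1.2), loc="upper right")
    return ax


# Usage on synthetic data ('x', 'y', 'z' are made-up columns).
rng = np.random.default_rng(1)
x = rng.normal(size=60)
demo = pd.DataFrame({"x": x,
                     "y": x + rng.normal(scale=2.0, size=60),
                     "z": rng.normal(size=60)})
corr = demo.corr()
# pandas puts 1 on the diagonal for a callable method, so subtract the identity
pvals = demo.corr(method=lambda a, b: pearsonr(a, b)[1]) - np.eye(len(corr))

plt.figure(figsize=(6, 6))
ax = heatmap_by_significance(corr, pvals, levels=(0.01, 0.05))
# then plt.show() or plt.savefig(...) as usual
```

Pinning vmin and vmax keeps the colours comparable across figures; with the default limits the colour range would be set by the significant cells alone, as in the answer's version.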