
Seaborn Correlation Matrix with p-Values in Python


I have a diagonal correlation matrix produced with seaborn. I would like to mask out the correlations whose p-value is greater than 0.05.

Here's what I have so far: https://i.sstatic.net/16Rky.jpg

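import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

sns.set(style="white")
corr = result.corr()  # `result` is the DataFrame from the question
print(corr)

# hide the upper triangle (including the diagonal)
mask = np.zeros_like(corr, dtype=bool)
mask[np.triu_indices_from(mask)] = True
f, ax = plt.subplots(figsize=(11, 9))
sns_plot = sns.heatmap(corr, mask=mask, annot=True, center=0, square=True,
                       fmt=".1f", linewidths=.5, cmap="Greens", ax=ax)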

I would greatly appreciate any help with this. Many thanks!
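For reference, the mask in the question hides every cell on or above the diagonal, because `np.triu_indices_from` returns the row and column indices of the upper triangle, diagonal included. A minimal sketch on a hypothetical 3x3 matrix (the name `demo_mask` is made up for illustration) shows the resulting pattern: True on and above the diagonal, False below it.

```python
import numpy as np

# Hypothetical 3x3 array standing in for a correlation matrix
demo_mask = np.zeros((3, 3), dtype=bool)
# np.triu_indices_from gives the (row, col) indices of the upper triangle, diagonal included
demo_mask[np.triu_indices_from(demo_mask)] = True
print(demo_mask)
```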


Solution

  • For completeness, here is a solution that uses scipy.stats.pearsonr to build a matrix of p-values and then turns it into a boolean mask to pass to seaborn (optionally combined with np.tril or np.triu to hide the upper triangle of the correlation matrix).

    import numpy as np
    from scipy import stats

    def corr_sig(df):
        """Return a square matrix of Pearson p-values for every pair of columns in df."""
        n = df.shape[1]
        p_matrix = np.zeros(shape=(n, n))  # the diagonal is left at 0
        for i, col in enumerate(df.columns):
            for j, col2 in enumerate(df.columns):
                if i != j:
                    _, p = stats.pearsonr(df[col], df[col2])
                    p_matrix[i, j] = p
        return p_matrix

    p_values = corr_sig(df)
    mask = np.invert(np.tril(p_values < 0.05))
    # note: seaborn hides every cell where the mask is True
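As an alternative sketch (not part of the original answer), pandas' `DataFrame.corr` also accepts a callable as `method`, so the p-value matrix can be built without writing the loops by hand. pandas puts 1 on the diagonal of the result regardless of what the callable returns, so the diagonal is reset to 0 here to match `corr_sig`. The helper `corr_pvalues` and the random `demo` data are hypothetical, for illustration only.

```python
import numpy as np
import pandas as pd
from scipy import stats

def corr_pvalues(df):
    """Hypothetical helper: Pearson p-values for every column pair via DataFrame.corr."""
    # The callable receives two 1-D arrays and must return a float (here: the p-value)
    p = df.corr(method=lambda x, y: stats.pearsonr(x, y)[1]).to_numpy(copy=True)
    # pandas forces the diagonal to 1.0; reset it to 0 so it matches corr_sig
    np.fill_diagonal(p, 0.0)
    return p

# Small random example (assumed data): "d" is built to be strongly tied to "a"
rng = np.random.default_rng(0)
demo = pd.DataFrame(rng.normal(size=(50, 3)), columns=["a", "b", "c"])
demo["d"] = demo["a"] * 2 + rng.normal(scale=0.1, size=50)

p_values_alt = corr_pvalues(demo)
mask_alt = np.invert(np.tril(p_values_alt < 0.05))
```

The loop in `corr_sig` makes the computation explicit; the callable version is shorter and also computes each pair only once, since pandas fills the matrix symmetrically.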
    


    Complete Process with Examples

    First, create some sample data: three correlated variables and three uncorrelated ones (two random integer columns and one nearly invariant column):

    import pandas as pd
    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt
    from scipy import stats
    
    # Simulate 3 correlated variables
    num_samples = 100
    mu = np.array([5.0, 0.0, 10.0])
    # The desired covariance matrix.
    r = np.array([
            [  3.40, -2.75, -2.00],
            [ -2.75,  5.50,  1.50],
            [ -2.00,  1.50,  1.25]
        ])
    y = np.random.multivariate_normal(mu, r, size=num_samples)
    df = pd.DataFrame(y)
    df.columns = ["Correlated1","Correlated2","Correlated3"]
    
    # Add two random integer variables that are independent of the others
    for i in range(2):
        df[f"Uncorrelated{i}"] = np.random.randint(-2000, 2000, len(df))

    # Also add a nearly invariant variable (values -99 to -96), which should be uncorrelated as well
    df["Near Invariant"] = np.random.randint(-99, -95, num_samples)
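The array `r` above is a covariance matrix, not a correlation matrix. Dividing each entry by the product of the two standard deviations, rho_ij = cov_ij / sqrt(cov_ii * cov_jj), gives the correlations the simulation should produce on average (roughly -0.64, -0.97 and 0.57), which is a quick sanity check for the three correlated columns. A minimal sketch; the names `cov`, `sd` and `target_corr` are illustrative.

```python
import numpy as np

# Covariance matrix used for the three correlated variables above
cov = np.array([
    [ 3.40, -2.75, -2.00],
    [-2.75,  5.50,  1.50],
    [-2.00,  1.50,  1.25],
])

# rho_ij = cov_ij / (sd_i * sd_j)
sd = np.sqrt(np.diag(cov))
target_corr = cov / np.outer(sd, sd)
print(np.round(target_corr, 2))
```

With only 100 samples, the sample correlations in `df.corr()` will scatter around these targets rather than match them exactly.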
    

    Plotting function for convenience
    It mainly takes care of the cosmetics of the heatmap.

    def plot_cor_matrix(corr, mask=None):
        f, ax = plt.subplots(figsize=(11, 9))
        sns.heatmap(corr, ax=ax,
                    mask=mask,
                    # cosmetics
                    annot=True, vmin=-1, vmax=1, center=0,
                    cmap='coolwarm', linewidths=2, linecolor='black', cbar_kws={'orientation': 'horizontal'})
    

    Corr. Plot of Example Data with All Correlations
    To show what the correlations in this example matrix look like without filtering for significant correlations (p < .05):

    # Plotting without significance filtering
    corr = df.corr()
    mask = np.triu(np.ones_like(corr, dtype=bool))  # hide the upper triangle and the diagonal
    plot_cor_matrix(corr,mask)
    plt.show()
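`np.triu(corr)` returns the correlation values themselves rather than booleans, so using it as a mask relies on every cell to be hidden being nonzero (an array mask is interpreted as booleans, nonzero meaning hidden). A boolean triangle built with `np.ones_like(..., dtype=bool)` does not depend on the values. The 3x3 matrix below is a made-up example with an exact zero above the diagonal:

```python
import numpy as np

# Made-up 3x3 "correlation" matrix with an exact 0.0 above the diagonal
corr_demo = np.array([
    [1.0, 0.0, 0.3],
    [0.0, 1.0, 0.5],
    [0.3, 0.5, 1.0],
])

# Casting the triangle of values to bool leaves the 0.0 cell unmasked
value_mask = np.triu(corr_demo).astype(bool)
# A boolean triangle hides the whole upper triangle regardless of the values
bool_mask = np.triu(np.ones_like(corr_demo, dtype=bool))

print(value_mask[0, 1], bool_mask[0, 1])
```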
    


    Corr. Plot of Example Data with Only Significant Correlations

    Finally, plot only the correlations that are significant at alpha = .05 (p < .05):

    # Plotting with significance filter
    corr = df.corr()                            # get correlation
    p_values = corr_sig(df)                     # get p-Value
    mask = np.invert(np.tril(p_values<0.05))    # mask - only get significant corr
    plot_cor_matrix(corr, mask)
    plt.show()
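The mask `np.invert(np.tril(p_values < 0.05))` hides two kinds of cells: everything strictly above the diagonal, and any cell on or below it whose p-value is not below .05. Written as an explicit OR of those two conditions it looks like this; the p-value matrix is a hypothetical 3x3 example with a zero diagonal, as `corr_sig` leaves it.

```python
import numpy as np

# Hypothetical symmetric p-value matrix (diagonal 0, as corr_sig leaves it)
p_demo = np.array([
    [0.00, 0.01, 0.40],
    [0.01, 0.00, 0.03],
    [0.40, 0.03, 0.00],
])

mask_original = np.invert(np.tril(p_demo < 0.05))
# Same mask as an OR: strict upper triangle, or not significant
upper = np.triu(np.ones_like(p_demo, dtype=bool), k=1)
mask_explicit = upper | (p_demo >= 0.05)

print(np.array_equal(mask_original, mask_explicit))
```

The explicit form makes it easy to change the threshold or to drop the triangle condition and show the full matrix.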
    


    Conclusion

    In the first correlation matrix, some correlation coefficients (r) are greater than .05 (the filter suggested in the comments on the question), but that does not imply that their p-values are significant. It is therefore important to distinguish the p-value from the correlation coefficient r.
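The p-value of a Pearson correlation depends on both r and the sample size n, through the t statistic t = r * sqrt(n - 2) / sqrt(1 - r^2) with n - 2 degrees of freedom. A minimal sketch with an assumed r = 0.2 (the helper `pearson_p` is hypothetical) shows the same coefficient failing the .05 threshold at n = 20 and passing it at n = 200:

```python
import numpy as np
from scipy import stats

def pearson_p(r, n):
    """Two-sided p-value for a Pearson correlation r from n samples (t-test form)."""
    t = r * np.sqrt(n - 2) / np.sqrt(1 - r**2)
    # Survival function of Student's t, doubled for a two-sided test
    return 2 * stats.t.sf(abs(t), df=n - 2)

for n in (20, 200):
    print(n, round(pearson_p(0.2, n), 4))
```

This t-test form should agree with the p-value returned by `scipy.stats.pearsonr` for the same r and n.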

    I hope this answer helps others who are looking for a way to plot only significant correlations with sns.heatmap.