Tags: python, python-3.x, tensorflow, seaborn, heatmap

Seaborn Heatmap - Display the heatmap only if values are above a given threshold


The Python code below displays sentence similarity; it uses the Universal Sentence Encoder to compute the sentence embeddings.

from absl import logging

import tensorflow as tf

import tensorflow_hub as hub
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import re
import seaborn as sns

module_url = "https://tfhub.dev/google/universal-sentence-encoder/4" 
model = hub.load(module_url)
print ("module %s loaded" % module_url)
def embed(input):
  return model(input)


def plot_similarity(labels, features, rotation):
  corr = np.inner(features, features)
  print(corr)
  sns.set(font_scale=2.4)
  plt.subplots(figsize=(40,30))
  g = sns.heatmap(
      corr,
      xticklabels=labels,
      yticklabels=labels,
      vmin=0,
      vmax=1,
      cmap="YlGnBu",linewidths=1.0)
  g.set_xticklabels(labels, rotation=rotation)
  g.set_title("Semantic Textual Similarity")

def run_and_plot(messages_):
  message_embeddings_ = embed(messages_)
  plot_similarity(messages_, message_embeddings_, 90)


messages = [
"I want to know my savings account balance",
"Show my bank balance",
"Show me my account",
"What is my bank balance",
"Please Show my bank balance"    

]

run_and_plot(messages)

The output is displayed as a heatmap (image not reproduced here), and the similarity values are also printed.

I want to focus only on the sentences that seem quite similar; however, the heatmap currently displays all the values.

So

  1. Is there a way to view the heatmap showing only values greater than 0.6 and less than 0.999?

  2. Is it possible to print the matching value pairs that lie within the given range, i.e., between 0.6 and 0.99?

Thanks, Rohit


Solution

  • Following your question update, here is a revised version. In a grid, one cannot delete an individual cell, but we can delete every row and column that contains no value in the target range, which can shrink the heatmap substantially while keeping all relevant pairs. The more randomly scattered the in-range values are across the matrix, the fewer rows and columns can be removed, so the reduction becomes less noticeable.

    import pandas as pd
    import matplotlib.pyplot as plt
    import numpy as np
    from copy import copy
    import seaborn as sns
    
    #semi-random data generation 
    labels = list("ABCDE")
    np.random.seed(123)
    df = pd.DataFrame(np.random.randint(1, 100, (20, 5)))
    df.columns = labels
    df.A = df.B - df.D
    df.C = df.B + df.A
    df.E = df.A + df.C
    
    #your correlation array
    corr = df.corr().to_numpy()
    print(corr)
    
    #conditions for filtering: keep values with 0.6 <= r <= 0.99
    val_min = 0.6
    val_max = 0.99
    
    #plotting starts here
    sns.set(font_scale=2.4)
    #two axis objects just for comparison purposes
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15,8))
    
    #define the colormap with clipping values
    my_cmap = copy(plt.cm.YlGnBu)
    my_cmap.set_over("white")
    my_cmap.set_under("white")
    
    #ax1 - full set of conditions as in the initial version 
    g1 = sns.heatmap(corr,
        xticklabels=labels,
        yticklabels=labels,
        vmin=val_min,
        vmax=val_max,
        cmap=my_cmap,
        linewidths=1.0,
        linecolor="grey",
        ax=ax1)
    
    g1.set_title("Entire heatmap")
    
    #ax2 - remove empty rows/columns
    # use only lower triangle
    corr = np.tril(corr)
    
    #delete columns in which no element fulfills the conditions
    ind_x, = np.where(np.all(np.logical_or(corr<val_min, corr>val_max), axis=0))
    corr = np.delete(corr, ind_x, 1)
    #update x labels
    map_labels_x = [item for i, item in enumerate(labels) if i not in ind_x]
        
    #now the same for rows 
    ind_y, = np.where(np.all(np.logical_or(corr<val_min, corr>val_max), axis=1))
    corr = np.delete(corr, ind_y, 0)
    #update y labels
    map_labels_y = [item for i, item in enumerate(labels) if i not in ind_y]
    
    #plot heatmap
    g2 = sns.heatmap(corr,
        xticklabels=map_labels_x,
        yticklabels=map_labels_y,
        vmin=val_min,
        vmax=val_max,
        cmap=my_cmap,
        linewidths=1.0,
        linecolor="grey", ax=ax2)
    
    g2.set_title("Reduced heatmap")
    
    plt.show()
    

    Sample output:

    Left: the original approach, showing all cells of the heatmap. Right: only the rows and columns that contain at least one relevant pair are kept. The question (and therefore the code) excludes strong negative correlations, e.g., -0.95. If these should be included as well, apply np.abs() to the correlation matrix before filtering.
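The code above does not address the second question, printing the matching pairs. The following is a minimal sketch of one way to do it with NumPy only; `pairs_in_range` is a hypothetical helper name, not part of NumPy or seaborn. It scans only the strict lower triangle, so each pair is reported once and the diagonal (each sentence compared with itself) is skipped. Its `use_abs` flag covers the negative-correlation case mentioned above.

```python
import numpy as np


def pairs_in_range(corr, labels, val_min=0.6, val_max=0.99, use_abs=False):
    """Hypothetical helper: return (label_i, label_j, value) for every pair
    i > j whose value lies in [val_min, val_max]."""
    corr = np.asarray(corr, dtype=float)
    # Optionally compare magnitudes so strong negative values (e.g. -0.95) also match
    values = np.abs(corr) if use_abs else corr
    in_range = (values >= val_min) & (values <= val_max)
    # k=-1 keeps only entries strictly below the diagonal: each pair once, no self-pairs
    rows, cols = np.nonzero(np.tril(in_range, k=-1))
    # Report the original (signed) value, even when filtering on the magnitude
    return [(labels[i], labels[j], corr[i, j]) for i, j in zip(rows, cols)]


if __name__ == "__main__":
    # Small made-up symmetric matrix, for illustration only
    labels = ["A", "B", "C"]
    corr = np.array([
        [1.00, 0.75, -0.95],
        [0.75, 1.00, 0.30],
        [-0.95, 0.30, 1.00],
    ])
    for a, b, v in pairs_in_range(corr, labels):
        print(f"{a} - {b}: {v:.3f}")  # prints "B - A: 0.750"
    for a, b, v in pairs_in_range(corr, labels, use_abs=True):
        print(f"{a} - {b}: {v:.3f}")  # prints "B - A: 0.750" and "C - A: -0.950"
```

With the question's code, this would presumably be called as `emb = embed(messages)` followed by `pairs_in_range(np.inner(emb, emb), messages)`, i.e., on the same matrix that `plot_similarity` computes.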

    Initial answer
    I am surprised that nobody has provided a self-contained solution yet, so here is one:

    import pandas as pd
    import matplotlib.pyplot as plt
    import numpy as np
    from copy import copy
    import seaborn as sns
    
    labels = list("ABCDE")
    #semi-random data
    np.random.seed(123)
    df = pd.DataFrame(np.random.randint(1, 100, (20, 5)))
    df.columns = labels
    df.A = df.B - df.D
    df.E = df.A + df.C
    
    corr = df.corr()
    sns.set(font_scale=2.4)
    plt.subplots(figsize=(10,8))
    
    #define the cmap with clipping values
    my_cmap = copy(plt.cm.YlGnBu)
    my_cmap.set_over("white")
    my_cmap.set_under("white")
    
    g = sns.heatmap(corr,
        xticklabels=labels,
        yticklabels=labels,
        vmin=0.5,
        vmax=0.9,
        cmap=my_cmap,
        linewidths=1.0,
        linecolor="grey")
    
    g.set_xticklabels(labels, rotation=60)
    g.set_title("Important!")
    
    plt.show()
    

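As an alternative to clipping the colormap with set_over/set_under, seaborn's heatmap accepts a boolean `mask` argument, and cells where the mask is True are not drawn. The sketch below builds such a mask with NumPy; `out_of_range_mask` is a hypothetical helper name. It hides values outside the range and, optionally, the diagonal and upper triangle, since a similarity matrix is symmetric and every pair would otherwise appear twice.

```python
import numpy as np


def out_of_range_mask(corr, val_min=0.6, val_max=0.99, upper=True):
    """Hypothetical helper: boolean mask for sns.heatmap(mask=...),
    where True hides a cell."""
    corr = np.asarray(corr, dtype=float)
    # Hide everything outside [val_min, val_max]
    mask = (corr < val_min) | (corr > val_max)
    if upper:
        # np.triu with the default k=0 marks the diagonal and everything above it
        mask |= np.triu(np.ones_like(mask, dtype=bool))
    return mask


if __name__ == "__main__":
    # Small made-up symmetric matrix, for illustration only
    corr = np.array([
        [1.00, 0.75, -0.95],
        [0.75, 1.00, 0.30],
        [-0.95, 0.30, 1.00],
    ])
    # Only the (1, 0) cell (value 0.75) is left visible
    print(out_of_range_mask(corr))
```

In the question's `plot_similarity`, the mask could be passed as `sns.heatmap(corr, mask=out_of_range_mask(corr, 0.6, 0.999), vmin=0, vmax=1, cmap="YlGnBu", ...)`, which keeps the original 0 to 1 color scale. Masked cells show the axes background rather than white, so with the theme set by `sns.set` they will likely appear light grey.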