
Seaborn heatmap colorbar: how to ensure the correct order of classes and that the correct colors are displayed


I have a dataframe with results from a certain calculation that I would like to plot as a seaborn heatmap with a color bar. I'm using the following code to achieve that (mostly taken from an existing example):

import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

# input data
results = [['equal','equal','smaller','smaller or equal','greater or equal'],   
           ['equal','equal','smaller','smaller','greater or equal'],                                      
           ['greater','equal','smaller or equal','smaller','smaller'],
           ['equal','smaller or equal','greater or equal','greater or equal','equal'],
           ['equal','equal','smaller','equal','equal']]

index = ['axc', 'org', 'cf5', 'cm1', 'ext']
columns = ['axc', 'org', 'cf5', 'cm1', 'ext']

# create a dataframe
res_df = pd.DataFrame(results, index=index, columns=columns)

value_to_int = {j:i for i,j in enumerate(['greater','greater or equal','equal','smaller or equal','smaller'])}

n = len(value_to_int)     

# discrete colormap (n samples from a given cmap)
cmap = sns.color_palette("viridis", n) 
ax = sns.heatmap(res_df.replace(value_to_int), cmap=cmap) 

# modify colorbar:
colorbar = ax.collections[0].colorbar 
r = colorbar.vmax - colorbar.vmin 
colorbar.set_ticks([colorbar.vmin + r / n * (0.5 + i) for i in range(n)])
colorbar.set_ticklabels(list(value_to_int.keys()))                                          
plt.show()


It works like a charm most of the time, but a problem arises if one of the classes in the category list does not occur in the data. To demonstrate, change the input data like this (the 'greater' class no longer occurs):

results_changed = [['equal','equal','smaller','smaller or equal','greater or equal'],
              ['equal','equal','smaller','smaller','greater or equal'],
              ['greater or equal','equal','smaller or equal','smaller','smaller'],
              ['equal','smaller or equal','greater or equal','greater or equal','equal'],
              ['equal','equal','smaller','equal','equal']]

index = ['axc', 'org', 'cf5', 'cm1', 'ext']
columns = ['axc', 'org', 'cf5', 'cm1', 'ext']

# create a dataframe
res_df = pd.DataFrame(results_changed, index=index, columns=columns)

value_to_int = {j:i for i,j in enumerate(['greater','greater or equal','equal','smaller or equal','smaller'])}

n = len(value_to_int)  

# discrete colormap (n samples from a given cmap)
cmap = sns.color_palette("viridis", n) 
ax = sns.heatmap(res_df.replace(value_to_int), cmap=cmap) 

# modify colorbar:
colorbar = ax.collections[0].colorbar 
r = colorbar.vmax - colorbar.vmin 
colorbar.set_ticks([colorbar.vmin + r / n * (0.5 + i) for i in range(n)])
colorbar.set_ticklabels(list(value_to_int.keys()))                                          
plt.show()  

If you then plot this, the resulting heatmap assigns the wrong colors to the classes: since no cell is 'greater' now, the palette "shifts" and 'equal' is no longer assigned the same color as before.
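
The shift can be reproduced without plotting. The sketch below is a simplified model (an assumption, not matplotlib's actual code) of how a value is scaled linearly between the color limits and then assigned to one of n colors. It assumes that, when vmin and vmax are not passed, the limits are inferred from the data. The helper names `color_index` and `colors_for` are hypothetical.

```python
def color_index(value, vmin, vmax, n_colors):
    """Simplified model: scale value linearly to [0, 1], then pick one of n_colors bins."""
    fraction = (value - vmin) / (vmax - vmin)
    # The top edge (fraction == 1) is assigned to the last color.
    return min(int(fraction * n_colors), n_colors - 1)


category_order = ['greater', 'greater or equal', 'equal', 'smaller or equal', 'smaller']
value_to_int = {cat: i for i, cat in enumerate(category_order)}
n = len(category_order)


def colors_for(present):
    """Color index per category when the limits come from the codes actually present."""
    codes = [value_to_int[cat] for cat in present]
    vmin, vmax = min(codes), max(codes)
    return {cat: color_index(value_to_int[cat], vmin, vmax, n) for cat in present}


all_present = colors_for(category_order)          # every class occurs in the data
without_greater = colors_for(category_order[1:])  # 'greater' never occurs

print(all_present)
print(without_greater)
```

In this model 'equal' lands on color 2 when every class is present, but on color 1 once 'greater' is missing, which matches the shift seen in the plot.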


I've tried to remedy the problem by changing this line in the code:

value_to_int = {j:i for i,j in enumerate(pd.unique(res_df.values.ravel()))}

While this fixes the color assignment problem, it creates another one: the classes on the color bar are no longer in the intended order (which I would like to avoid).
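
The reordering comes from pd.unique returning values in order of first appearance rather than in the intended class order. The sketch below mimics that behavior without pandas, using `dict.fromkeys` as a stand-in (an assumption made only to keep the example dependency-free).

```python
results_changed = [['equal', 'equal', 'smaller', 'smaller or equal', 'greater or equal'],
                   ['equal', 'equal', 'smaller', 'smaller', 'greater or equal'],
                   ['greater or equal', 'equal', 'smaller or equal', 'smaller', 'smaller'],
                   ['equal', 'smaller or equal', 'greater or equal', 'greater or equal', 'equal'],
                   ['equal', 'equal', 'smaller', 'equal', 'equal']]

# Flatten row by row, similar to res_df.values.ravel()
flat = [cell for row in results_changed for cell in row]

# dict.fromkeys keeps the first occurrence of each value, in order of appearance
appearance_order = list(dict.fromkeys(flat))
value_to_int = {cat: i for i, cat in enumerate(appearance_order)}

print(appearance_order)
print(value_to_int)
```

Here the codes follow the order in which the classes happen to appear in the first row ('equal', 'smaller', 'smaller or equal', 'greater or equal'), so the color bar is no longer sorted from 'greater' to 'smaller'.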


Could anyone suggest how to fix this? I'd appreciate any suggestions.


Solution

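  • The best way to ensure comparability across different inputs is to always clamp the color bar to the same levels. Passing fixed vmin and vmax to sns.heatmap stops the color limits from being inferred from whichever classes happen to be present: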

    import pandas as pd
    from matplotlib import pyplot as plt
    import seaborn as sns
    
    results_changed = [['equal','equal','smaller','smaller or equal','greater or equal'],
                  ['equal','equal','smaller','smaller','greater or equal'],
                  ['greater or equal','equal','smaller or equal','smaller','smaller'],
                  ['equal','smaller or equal','greater or equal','greater or equal','equal'],
                  ['equal','equal','smaller','equal','equal']]
    
    index = ['axc', 'org', 'cf5', 'cm1', 'ext']
    columns = ['axc', 'org', 'cf5', 'cm1', 'ext']
    
    # create a dataframe
    res_df = pd.DataFrame(results_changed, index=index, columns=columns)
    
    #construct dictionary from ordered list
    category_order = ['greater', 'greater or equal', 'equal', 'smaller or equal', 'smaller']    
    value_to_int = {j:i for i,j in enumerate(category_order)}    
    n = len(value_to_int)  
    
    # discrete colormap (n samples from a given cmap)
    cmap = sns.color_palette("viridis", n) 
    ax = sns.heatmap(res_df.replace(value_to_int), cmap=cmap, vmin=0, vmax=n) 
    
    #modify colorbar:
    colorbar = ax.collections[0].colorbar 
    colorbar.set_ticks([0.5 + i for i in range(n)])
    colorbar.set_ticklabels(category_order)                                          
    plt.show()  
    


    If you want the colorbar to show only the categories that actually occur in the data, you can prefilter the category list. However, this changes the color assigned to each category from one input array to the next, making the plots difficult to compare:

    import pandas as pd
    from matplotlib import pyplot as plt
    import seaborn as sns
    import numpy as np
    
    results_changed = [['equal','equal','smaller','smaller or equal','greater'],
                  ['equal','equal','smaller','smaller','greater'],
                  ['greater','equal','smaller','smaller','smaller'],
                  ['equal','smaller','greater','greater','equal'],
                  ['equal','equal','smaller','equal','equal']]
    
    index = ['axc', 'org', 'cf5', 'cm1', 'ext']
    columns = ['axc', 'org', 'cf5', 'cm1', 'ext']
    
    # create a dataframe
    res_df = pd.DataFrame(results_changed, index=index, columns=columns)
    
    unique_results = np.unique(results_changed)
    unique_categories = [cat for cat in ['greater','greater or equal','equal','smaller or equal','smaller'] if cat in unique_results]
    
    value_to_int = {j:i for i,j in enumerate(unique_categories)}
    
    n = len(value_to_int)  
    
    # discrete colormap (n samples from a given cmap)
    cmap = sns.color_palette("viridis", n) 
    ax = sns.heatmap(res_df.replace(value_to_int), cmap=cmap) 
    
    #modify colorbar:
    colorbar = ax.collections[0].colorbar 
    r = colorbar.vmax - colorbar.vmin 
    colorbar.set_ticks([colorbar.vmin + r / n * (0.5 + i) for i in range(n)])
    colorbar.set_ticklabels(unique_categories)
    plt.show()  
    

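
To make the trade-off between the two approaches concrete, here is a minimal sketch that compares the integer codes each one assigns. The two small inputs `input_a` and `input_b` and the helper names `fixed_codes` and `prefiltered_codes` are hypothetical and only serve as illustration.

```python
category_order = ['greater', 'greater or equal', 'equal', 'smaller or equal', 'smaller']


def fixed_codes(category_order):
    """Codes derived from the full ordered list, independent of the data."""
    return {cat: i for i, cat in enumerate(category_order)}


def prefiltered_codes(results, category_order):
    """Codes derived only from the categories that occur in the data."""
    present = {cell for row in results for cell in row}
    kept = [cat for cat in category_order if cat in present]
    return {cat: i for i, cat in enumerate(kept)}


# Two hypothetical inputs that contain different subsets of the classes
input_a = [['greater', 'equal'], ['smaller', 'equal']]
input_b = [['equal', 'smaller'], ['smaller or equal', 'equal']]

fixed = fixed_codes(category_order)
codes_a = prefiltered_codes(input_a, category_order)
codes_b = prefiltered_codes(input_b, category_order)

print(fixed['equal'])                      # same code for every input
print(codes_a['equal'], codes_b['equal'])  # code depends on the input
```

With the fixed list, 'equal' always gets code 2, so it keeps its color as long as vmin=0 and vmax=n are passed as in the first example. With prefiltering, 'equal' gets code 1 for `input_a` and code 0 for `input_b`, so the same class is drawn in different colors in the two plots.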