python machine-learning seaborn heatmap confusion-matrix

Color confusion matrix heatmap using percentage of correctness


This is example code that plots a confusion matrix for a multi-class classification problem.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

cf_matrix = np.array([[50, 2, 38],
                      [7, 43, 32],
                      [1,  0, 4]])

labels = ['col1','col2','col3']
df_confusion = pd.DataFrame(cf_matrix, index = labels, columns=labels)
df_confusion['TOTAL'] = df_confusion.sum(axis=1)
df_confusion.loc['TOTAL']= df_confusion.sum()

plt.figure(figsize=(24, 10))

sns.set(font_scale = 1.5)
ax = sns.heatmap(df_confusion, annot=True, cmap='Blues', fmt="d")    
ax.set_title('Confusion Matrix\n\n',size=22)
ax.set_xlabel('\nPredicted Values',size=20)
ax.set_ylabel('Actual Values ', size=20)

plt.show()

This is the result: [image: the heatmap produced by the code above]

How can I change the colorbar so that the color is based not on the number of elements in each cell, but on the percentage that each cell represents of the total actual elements for its class (row)? For example, the third class, col3, would then get the darkest color, because 4/5 = 80% of its elements are predicted correctly, compared with 50/90 ≈ 56% for col1 and 43/82 ≈ 52% for col2.
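As a quick check of these numbers, here is a minimal sketch of the per-row normalization, using only NumPy. The names row_totals and row_pct are illustrative and do not appear in the code above.

```python
import numpy as np

# Same confusion matrix as above: rows are actual classes, columns are predictions.
cf_matrix = np.array([[50, 2, 38],
                      [7, 43, 32],
                      [1,  0, 4]])
labels = ['col1', 'col2', 'col3']

# Total number of actual elements per class (one value per row).
row_totals = cf_matrix.sum(axis=1, keepdims=True)

# Share of each cell within its row; every row now sums to 1.
row_pct = cf_matrix / row_totals

# The diagonal holds the fraction of correct predictions for each class.
for label, share in zip(labels, np.diag(row_pct)):
    print(f"{label}: {share:.1%}")
# col1: 55.6%
# col2: 52.4%
# col3: 80.0%
```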


Solution

  • Since the heatmap colors are derived from the values passed as data, you need to pass the percentages as data and then use the annot parameter to display the actual counts in the cells instead.

    So, I think you want something like the code below. N.B. I've set the percentages in df_percentages.TOTAL to 0; otherwise every cell in the TOTAL column would be 1.0 (100%) and the whole column would turn dark blue.

    Anyway, now that you know the logic, I trust you'll know how to adjust the values of df_percentages to your liking.

    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns
    
    cf_matrix = np.array([[50, 2, 38],
                          [7, 43, 32],
                          [1,  0, 4]])
    
    labels = ['col1','col2','col3']
    df_confusion = pd.DataFrame(cf_matrix, index = labels, columns=labels)
    df_confusion['TOTAL'] = df_confusion.sum(axis=1)
    df_confusion.loc['TOTAL']= df_confusion.sum()
    
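    # get percentages: divide each row by its own total
    df_percentages = df_confusion.div(df_confusion['TOTAL'], axis=0)
    # zero out the TOTAL column so it does not dominate the color scale
    df_percentages['TOTAL'] = 0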
    
    # =============================================================================
    #            col1      col2      col3  TOTAL
    # col1   0.555556  0.022222  0.422222      0
    # col2   0.085366  0.524390  0.390244      0
    # col3   0.200000  0.000000  0.800000      0
    # TOTAL  0.327684  0.254237  0.418079      0
    # =============================================================================
    
    plt.figure(figsize=(24, 10))
                            
    sns.set(font_scale = 1.5)
    
    # colors come from data (the percentages); cell text comes from annot (the counts)
    ax = sns.heatmap(data=df_percentages, annot=df_confusion, cmap='Blues', fmt="d",
                     cbar_kws={'label': 'percentages'})
      
    ax.set_title('Confusion Matrix\n\n',size=22)
    ax.set_xlabel('\nPredicted Values',size=20)
    ax.set_ylabel('Actual Values ', size=20)
    
    plt.show()
    

    Result:

    [image: plot heatmap]
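As one possible way to adjust df_percentages further, the sketch below also zeroes the TOTAL row and pins the color scale to the 0 to 1 range with vmin and vmax. This is a variation under stated assumptions, not part of the answer above: the helper names row_percentages and plot_confusion are hypothetical, and zeroing the TOTAL row is a choice made here so that only the class cells carry color.

```python
import numpy as np
import pandas as pd


def row_percentages(cf_matrix, labels):
    """Return (counts with TOTAL margins, row shares with the margins zeroed)."""
    df_counts = pd.DataFrame(cf_matrix, index=labels, columns=labels)
    df_counts["TOTAL"] = df_counts.sum(axis=1)
    df_counts.loc["TOTAL"] = df_counts.sum()
    # Keep integer counts so that fmt="d" can format the annotations.
    df_counts = df_counts.astype(int)

    # Divide every row by its own total so each class row sums to 1.
    df_pct = df_counts.div(df_counts["TOTAL"], axis=0)

    # Assumption: the margins should not compete with the class cells for color,
    # so both the TOTAL column and the TOTAL row are set to 0.
    df_pct["TOTAL"] = 0.0
    df_pct.loc["TOTAL"] = 0.0
    return df_counts, df_pct


def plot_confusion(df_counts, df_pct):
    """Color cells by row share and annotate them with the raw counts."""
    # Imported here so row_percentages can be used without the plotting libraries.
    import matplotlib.pyplot as plt
    import seaborn as sns

    ax = sns.heatmap(df_pct, annot=df_counts, fmt="d", cmap="Blues",
                     vmin=0, vmax=1,  # pin the colorbar to the full 0-1 range
                     cbar_kws={"label": "share of actual class"})
    ax.set_xlabel("Predicted Values")
    ax.set_ylabel("Actual Values")
    plt.show()


df_counts, df_pct = row_percentages(
    np.array([[50, 2, 38], [7, 43, 32], [1, 0, 4]]), ["col1", "col2", "col3"]
)
```

Calling plot_confusion(df_counts, df_pct) would then draw the heatmap. With vmin=0 and vmax=1 fixed, the same shade means the same share across different plots, which may or may not be what you want if all your shares are small.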