Search code examples
python, matplotlib, scikit-learn, jupyter-notebook, seaborn

How to plot accuracy, precision and recall in confusion matrix plot using seaborn


I want to show the precision, recall and accuracy of my predictions in a confusion matrix plot using seaborn. The figure I want to achieve is shown below.

This image:

This image.

So I tried to fit my model with data, and got the following confusion matrix.

Image_confusion matrix:

Image_confusion matrix.

Confusion Matrix:

col_0          0   1   2   3
Damage state                
0             65   0   0   0
1              0  51   0   0
2              0   0  31   1
3              0   0   0  52

I printed the classification report and got these values.

    precision    recall  f1-score   support

           0       1.00      1.00      1.00        65
           1       1.00      1.00      1.00        51
           2       1.00      0.97      0.98        32
           3       0.98      1.00      0.99        52

    accuracy                           0.99       200
   macro avg       1.00      0.99      0.99       200
weighted avg       1.00      0.99      0.99       200
# Prints the per-class precision / recall / f1-score / support table shown above.
print(classification_report(Y_train, y_train_pred))

Then I plotted my confusion matrix using seaborn.

#Import confusion matrix library
from sklearn.metrics import confusion_matrix

# Square matrix of counts computed from the true and predicted training labels.
c_mat = confusion_matrix(Y_train, y_train_pred)

# Each cell is annotated with its raw count and its share of all samples.
group_counts = ["{0:0.0f}".format(value) for value in
                c_mat.flatten()]
group_percentages = ["{0:.2%}".format(value) for value in
                     c_mat.flatten() / np.sum(c_mat)]
labels = [f"{v1}\n{v2}" for v1, v2 in
          zip(group_counts,group_percentages)]
# Reshape to the matrix's own shape so the code works for any number of classes.
labels = np.asarray(labels).reshape(c_mat.shape)

# fmt='' is required because the annotations are pre-formatted strings.
# sns.heatmap draws immediately and returns the matplotlib Axes, so no
# separate .plot() call is needed.
disp = sns.heatmap(c_mat, annot=labels, cmap='Reds', fmt='', xticklabels=["D1","D2","D3","D4"]
                   , yticklabels=["D1","D2","D3", "D4"], cbar=False)
plt.title('Predicted Damage State', fontweight='bold')
# Move the x tick labels to the top so they sit directly under the title.
plt.tick_params(labeltop=True, labelbottom=False)
plt.ylabel('Actual Damage State', fontweight='bold')
plt.show()

And I achieved this:

Confusion matrix achieved

Is there any way I can achieve the desired figure shown above?


Solution

  • You can draw a second heatmap onto the same subplot (ax). That heatmap starts from a matrix that has one more row and one more column than the given confusion matrix. The part that overlaps with the confusion matrix can be masked away.

    import matplotlib.pyplot as plt
    from matplotlib.colors import ListedColormap
    import seaborn as sns
    import numpy as np

    # Confusion matrix from the question: rows are the actual damage states,
    # columns are the predicted damage states.
    c_mat = np.array([[65, 0, 0, 0],
                      [0, 51, 0, 0],
                      [0, 0, 31, 1],
                      [0, 0, 0, 52]])
    total = np.sum(c_mat)
    # Each cell shows its raw count and its share of all samples.
    labels = [[f"{val:0.0f}\n{val / total:.2%}" for val in row] for row in c_mat]
    states = ["D1", "D2", "D3", "D4"]

    # First heatmap: the confusion matrix itself.
    # fmt='' is required because the annotations are pre-formatted strings.
    ax = sns.heatmap(c_mat, annot=labels, cmap='Reds', fmt='',
                     xticklabels=states, yticklabels=states, cbar=False)
    ax.set_title('Predicted Damage State', fontweight='bold')
    ax.tick_params(labeltop=True, labelbottom=False, length=0)
    ax.set_ylabel('Actual Damage State', fontweight='bold')

    # matrix for the extra column and row
    f_mat = np.zeros((c_mat.shape[0] + 1, c_mat.shape[1] + 1))
    f_mat[:-1, -1] = np.diag(c_mat) / np.sum(c_mat, axis=1)  # fill recall column
    f_mat[-1, :-1] = np.diag(c_mat) / np.sum(c_mat, axis=0)  # fill precision row
    f_mat[-1, -1] = np.trace(c_mat) / np.sum(c_mat)  # accuracy

    f_mask = np.ones_like(f_mat)  # puts 1 for masked elements
    f_mask[:, -1] = 0  # last column will be unmasked
    f_mask[-1, :] = 0  # last row will be unmasked

    # matrix for coloring the heatmap
    # only last row and column will be used due to masking
    f_color = np.ones_like(f_mat)
    f_color[-1, -1] = 0  # lower right gets different color

    # matrix of annotations, only last row and column will be used
    f_annot = [[f"{val:0.2%}" for val in row] for row in f_mat]
    f_annot[-1][-1] = "Acc.:\n" + f_annot[-1][-1]

    # Second heatmap on the same axes: only the extra row/column is visible,
    # the part overlapping the confusion matrix is masked away.
    # The tick labels reuse `states` and append one label for the extra
    # column (recall) and one for the extra row (precision).
    sns.heatmap(f_color, mask=f_mask, annot=f_annot, fmt='',
                xticklabels=states + ["Recall"],
                yticklabels=states + ["Precision"],
                cmap=ListedColormap(['skyblue', 'lightgrey']), cbar=False, ax=ax)
    plt.show()
    

    seaborn confusion matrix with extra column for recall and row for precision