Search code examples
python matplotlib seaborn

Seaborn Confusion Matrix - Set Data for Colorbar


The following code works; however, I want the colorbar and the cell colors to represent the percentage, not the count. I can't seem to find a way to specify which data the colorbar should use. Does anyone know how to do this?

# Example data: ground-truth classes and the corresponding predictions.
truth_labels = [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
pred_labels = [1, 1, 1, 1, 0, 1, 1, 0, 0, 0]

# Font settings shared by the title, the axis labels and the tick labels.
LABEL_SIZE = 40
TITLE_FONT_SIZE = {"size":"40"}
LABEL_FONT_SIZE = {"size":"40"}

# Predictions are passed as the first argument, so the rows correspond to the
# predicted classes (the y-axis is labelled "Predicted" below) and the columns
# to the actual classes.
conf_matrix = confusion_matrix(pred_labels, truth_labels, labels=[1, 0])

# Share of every cell within its own column (each column sums to 100%).
column_totals = conf_matrix.sum(axis=0, keepdims=True)
column_shares = (conf_matrix / column_totals).ravel()

# Two-line annotation per cell: raw count on top, column percentage below.
count_texts = ["{0:0.0f}".format(count) for count in conf_matrix.ravel()]
percent_texts = ["{0:.2%}".format(share) for share in column_shares]
cell_labels = np.asarray(
    [f"{count}\n{percent}" for count, percent in zip(count_texts, percent_texts)]
).reshape(2, 2)

sns.set(font_scale=4.0)
# NOTE: `ax` is not created in this snippet; it is assumed to be an existing
# Matplotlib Axes. The colors and the colorbar are driven by the raw counts
# in `conf_matrix`; `cell_labels` only controls the text drawn in each cell.
sns.heatmap(conf_matrix, annot=cell_labels, cmap="Blues", fmt="", ax=ax)

# Titles, axis labels, etc.
title = "Confusion Matrix\n"
ax.set_title(title, fontdict=TITLE_FONT_SIZE)
ax.set_xlabel("Actual", fontdict=LABEL_FONT_SIZE)
ax.set_ylabel("Predicted", fontdict=LABEL_FONT_SIZE)
ax.tick_params(axis="both", which="major", labelsize=LABEL_SIZE)
ax.set_xticklabels(["1", "0"])
ax.set_yticklabels(["1", "0"], rotation=90, va="center")

enter image description here


Solution

  • A simple solution that does not change your code much is to pass the confusion matrix, expressed as percentages, directly to the heatmap while keeping your original annotations:

    import seaborn as sns
    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn.metrics import confusion_matrix
    
    fig, ax = plt.subplots(figsize=(10, 10))

    # Example data: ground-truth classes and the corresponding predictions.
    truth_labels = [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
    pred_labels = [1, 1, 1, 1, 0, 1, 1, 0, 0, 0]

    # Single source of truth for the class order: it drives both the layout of
    # the confusion matrix and the tick labels, so they cannot drift apart.
    class_labels = [1, 0]

    TITLE_FONT_SIZE = {"size":"40"}
    LABEL_FONT_SIZE = {"size":"40"}
    LABEL_SIZE = 40

    # Predictions are passed as the first argument, so the rows correspond to
    # the predicted classes (y-axis, "Predicted") and the columns to the actual
    # classes (x-axis, "Actual").
    conf_matrix = confusion_matrix(pred_labels, truth_labels, labels=class_labels)

    # Fraction of every cell within its own column (each column sums to 1).
    # This is kept as a numeric array because the heatmap colors come from it.
    group_normalized_percentages = conf_matrix / np.sum(conf_matrix, axis=0, keepdims=True)

    # Two-line annotation per cell: raw count on top, column percentage below.
    group_counts = [f"{value:0.0f}" for value in conf_matrix.flatten()]
    percentage_texts = [f"{value:.2%}" for value in group_normalized_percentages.ravel()]
    cell_labels = [f"{v1}\n{v2}" for v1, v2 in zip(group_counts, percentage_texts)]
    # Reshape to the matrix's own shape rather than a hard-coded (2, 2), so the
    # snippet keeps working when more classes are added to `class_labels`.
    cell_labels = np.asarray(cell_labels).reshape(conf_matrix.shape)

    sns.set(font_scale=4.0)
    # The heatmap is colored by the percentages (scaled to 0-100), so the
    # colorbar shows percent values; the '%.0f%%' format appends the % sign
    # to the colorbar tick labels. The annotations still show count + percent.
    sns.heatmap(
        100.0 * group_normalized_percentages,
        annot=cell_labels,
        cmap="Blues",
        fmt="",
        ax=ax,
        cbar_kws={"format": "%.0f%%"},
    )

    # Titles, axis labels, etc.
    title = "Confusion Matrix\n"

    ax.set_title(title, fontdict=TITLE_FONT_SIZE)
    ax.set_xlabel("Actual", fontdict=LABEL_FONT_SIZE)
    ax.set_ylabel("Predicted", fontdict=LABEL_FONT_SIZE)
    ax.tick_params(axis="both", which="major", labelsize=LABEL_SIZE)
    # Tick labels are derived from `class_labels` so they always match the
    # row/column order used to build the matrix.
    tick_texts = [str(label) for label in class_labels]
    ax.set_xticklabels(tick_texts)
    ax.set_yticklabels(tick_texts, rotation=90, va="center")
    

    enter image description here