
Plot a Confusion Matrix in Python using a Dataframe of Strings


I'm running 10-fold cross-validation and need to see how the accuracy of each class varies across the folds. I managed to create a DataFrame like this:

Snippet:

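import pandas as pd

# means[i, j] and stds[i, j]: mean and standard deviation of cell (i, j) over the folds
chars = []
for i in range(int(classes) + 1):  # class indices 0..classes
    row = []
    for j in range(int(classes) + 1):
        # Format each cell as "mean +/- std", rounded to 3 decimals
        row.append(str(round(means[i, j], 3)) + " +/- " + str(round(stds[i, j], 3)))
    chars.append(row)

con_mat_df = pd.DataFrame(chars, index=classes_list, columns=classes_list)
print(con_mat_df)

Output: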
           0                1   ...               14               15
0    100.0 +/- 0.0      0.0 +/- 0.0  ...      0.0 +/- 0.0      0.0 +/- 0.0
1   0.49 +/- 0.703  98.53 +/- 1.416  ...      0.0 +/- 0.0      0.0 +/- 0.0
2      0.0 +/- 0.0    0.12 +/- 0.36  ...      0.0 +/- 0.0      0.0 +/- 0.0
3      0.0 +/- 0.0      0.0 +/- 0.0  ...      0.0 +/- 0.0      0.0 +/- 0.0
4      0.0 +/- 0.0      0.0 +/- 0.0  ...      0.0 +/- 0.0      0.0 +/- 0.0
5   0.55 +/- 0.905    0.14 +/- 0.42  ...      0.0 +/- 0.0      0.0 +/- 0.0
6      0.0 +/- 0.0      0.0 +/- 0.0  ...      0.0 +/- 0.0      0.0 +/- 0.0
7      0.0 +/- 0.0      0.0 +/- 0.0  ...      0.0 +/- 0.0      0.0 +/- 0.0
8      0.0 +/- 0.0      0.0 +/- 0.0  ...      0.0 +/- 0.0      0.0 +/- 0.0
9   0.62 +/- 1.318      0.2 +/- 0.6  ...      0.0 +/- 0.0      0.0 +/- 0.0
10  0.65 +/- 0.927   0.24 +/- 0.265  ...      0.0 +/- 0.0      0.0 +/- 0.0
11  1.02 +/- 1.558      0.0 +/- 0.0  ...      0.0 +/- 0.0   1.36 +/- 1.482
12     0.0 +/- 0.0      0.0 +/- 0.0  ...      0.0 +/- 0.0      0.0 +/- 0.0
13   0.32 +/- 0.96      0.0 +/- 0.0  ...      0.0 +/- 0.0      0.0 +/- 0.0
14  0.78 +/- 1.191      0.0 +/- 0.0  ...  98.96 +/- 1.274      0.0 +/- 0.0
15     0.0 +/- 0.0      0.0 +/- 0.0  ...      0.0 +/- 0.0  94.78 +/- 6.884
[16 rows x 16 columns]
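The question does not show how `means` and `stds` were computed. A minimal sketch, assuming each fold yields a confusion matrix of raw counts (rows = true class, columns = predicted class) and that each cell is expressed as a percentage of its true-class row (consistent with the table above, whose rows look like percentages), could use NumPy as follows. The names `mean_std_confusion` and `folds` are hypothetical.

```python
import numpy as np


def mean_std_confusion(fold_matrices):
    """Average row-normalized confusion matrices over folds (hypothetical helper).

    fold_matrices: array-like of shape (n_folds, n_classes, n_classes) with raw
    counts, rows = true class, columns = predicted class.
    Returns (means, stds), each of shape (n_classes, n_classes), in percent.
    """
    counts = np.asarray(fold_matrices, dtype=float)
    # Total of each true-class row, kept as a trailing axis for broadcasting
    row_totals = counts.sum(axis=2, keepdims=True)
    # Counts -> percent of each row; rows with no samples stay 0 instead of NaN
    percents = np.divide(counts, row_totals, out=np.zeros_like(counts),
                         where=row_totals > 0) * 100.0
    # Mean and population standard deviation (ddof=0) across the fold axis
    return percents.mean(axis=0), percents.std(axis=0)


# Two toy folds for a 2-class problem (made-up counts for illustration)
folds = [[[9, 1], [0, 10]],
         [[10, 0], [2, 8]]]
means, stds = mean_std_confusion(folds)
print(means)
print(stds)
```

`np.std` uses `ddof=0` by default; pass `ddof=1` if the sample standard deviation over folds is wanted instead.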

Now I want to plot it like the example below. If I pass this DataFrame to sns.heatmap, it throws an error (TypeError: ufunc 'isnan' not supported for the input types...). Any ideas? Thanks.

[Image: example of the desired plot]
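The `TypeError` most likely comes from the cells being strings: a heatmap needs numbers to map onto colors, and the traceback points at NumPy's `isnan`, which is not defined for text. If only the string table is at hand, a minimal sketch (assuming every cell has exactly the `"<mean> +/- <std>"` form produced by the snippet above; `parse_cells` is a hypothetical helper) recovers two float arrays that can be plotted:

```python
import numpy as np


def parse_cells(rows, sep="+/-"):
    """Split "mean +/- std" strings into two float arrays of the same shape.

    rows: nested list (or DataFrame.values) of strings such as "0.49 +/- 0.703".
    """
    means, stds = [], []
    for row in rows:
        # partition splits each cell at the separator; float() ignores the spaces
        pairs = [cell.partition(sep) for cell in row]
        means.append([float(m) for m, _, _ in pairs])
        stds.append([float(s) for _, _, s in pairs])
    return np.array(means), np.array(stds)


cells = [["100.0 +/- 0.0", "0.0 +/- 0.0"],
         ["0.49 +/- 0.703", "98.53 +/- 1.416"]]

# NaN checks fail on an array of strings, which is what breaks the heatmap
try:
    np.isnan(np.array(cells, dtype=object))
except TypeError as err:
    print("TypeError:", err)

means, stds = parse_cells(cells)
print(np.isnan(means).any())  # the numeric arrays can be checked and plotted
```

With the DataFrame from the question this would be called as `parse_cells(con_mat_df.values)`, although keeping the original numeric `means` and `stds` arrays around avoids the round trip entirely.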


Solution

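  • The easiest way I found was to skip the string DataFrame and plot the numeric arrays directly with plt.imshow, writing the mean ± std text into each cell with plt.text (cm is the array of means and cms is the array of standard deviations):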

    import matplotlib.pyplot as plt
    import numpy as np


    def plot_confusion_matrix(cm, cms, classes, cmap=plt.cm.Blues):
        """
        Plot a confusion matrix whose cells show "mean ± std".

        cm  -- 2-D NumPy array of per-cell means (used for the colors)
        cms -- 2-D NumPy array of per-cell standard deviations (same shape)
        """
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.colorbar()
        tick_marks = np.arange(len(classes))
        plt.xticks(tick_marks, classes, rotation=45)
        plt.yticks(tick_marks, classes)

        # White text on dark cells, black text on light cells
        thresh = cm.max() / 2.

        for i in range(cm.shape[0]):
            for j in range(cm.shape[1]):
                # '\n' is a real newline; '\\pm' is rendered as ± by mathtext
                plt.text(j, i, '{0:.2f}\n$\\pm${1:.2f}'.format(cm[i, j], cms[i, j]),
                         horizontalalignment="center",
                         verticalalignment="center", fontsize=7,
                         color="white" if cm[i, j] > thresh else "black")

        plt.ylabel('True label')
        plt.xlabel('Predicted label')
        # Call tight_layout after setting the labels so they are not clipped
        plt.tight_layout()


    # Plot the non-normalized confusion matrix (means and stds are NumPy arrays)
    plt.figure()
    plot_confusion_matrix(means, stds, classes=classes_list)
    plt.show()
    

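Since the question started from `sns.heatmap`, an alternative sketch (assuming seaborn is installed) keeps that route: pass the numeric means as `data` and a same-shaped array of label strings as `annot`, with `fmt=""` so seaborn inserts the strings as-is instead of applying its default numeric format. `heatmap_mean_std` is a hypothetical helper, and the random matrices only stand in for the real `means` and `stds`.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; remove this line to use plt.show()
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def heatmap_mean_std(means, stds, labels, ax=None):
    """Color cells by `means` and annotate each one with "mean ± std"."""
    # One label string per cell, same shape as `means`
    annot = np.array([[f"{m:.2f}\n±{s:.2f}" for m, s in zip(mrow, srow)]
                      for mrow, srow in zip(means, stds)])
    # fmt="" makes seaborn insert the annotation strings unchanged
    ax = sns.heatmap(means, annot=annot, fmt="", cmap="Blues",
                     annot_kws={"fontsize": 7},
                     xticklabels=labels, yticklabels=labels, ax=ax)
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    return ax


# Placeholder data standing in for the real means/stds arrays
rng = np.random.default_rng(0)
n = 4
demo_means = rng.uniform(0, 100, size=(n, n))
demo_stds = rng.uniform(0, 5, size=(n, n))

ax = heatmap_mean_std(demo_means, demo_stds, labels=list(range(n)))
plt.tight_layout()
plt.savefig("confusion_matrix.png")
```

Unlike the threshold rule in the imshow version, seaborn chooses the annotation text color from each cell's brightness on its own, so no manual `thresh` is needed.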