Tags: python, matplotlib, jupyter-notebook, confusion-matrix

Incomplete confusion matrix when plotting with matshow


I'm trying to plot this confusion matrix:

[[25940  2141    84    19     3     0     0     1   184     4]
 [ 3525  6357   322    41     5     1     3     0   242     2]
 [  410  1484  1021    80     5     6     0     0   282     0]
 [   98   285   189   334     9     9     5     1   140     0]
 [   26    64    55    50   112    15     4     1    75     0]
 [   11    45    20    24     5   118     8     0    79     0]
 [    1     8     8     5     0    10    62     1    55     0]
 [    2     0     0     0     0     0     2     0     6     0]
 [  510   524   103    55     5     7     7     1 65350     0]
 [   62    13     2     1     0     0     1     0    11    13]]

So it is a 10x10 matrix, and its 10 labels are:

[ 5  6  7  8  9 10 11 12 14 15]

I use the following code:

# Get the confusion matrix
cm = confusion_matrix(y_test, y_pred, labels=labels)
print('Confusion Matrix of {} is:\n{}'.format(clf_name, cm))
print(labels)
mat = plt.matshow(cm, interpolation='nearest')
ax = plt.gca()
ax.set_xticklabels([''] + labels.astype(str).tolist())
ax.set_yticklabels([''] + labels.astype(str).tolist())
plt.title('Confusion matrix of the {} classifier'.format(clf_name))
plt.colorbar(mat, extend='both')
plt.clim(0, 100)

And I only get a plot with labels from 5 to 9:

[Image: the resulting plot, with tick labels only from 5 to 9]

What's the problem here?

Relevant imports and configuration (I'm working with Jupyter, btw):

import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn.metrics import confusion_matrix
%matplotlib inline
plt.style.use('seaborn')
mpl.rcParams['figure.figsize'] = 8, 6

I tried downgrading to matplotlib 3.1.0, since I read that 3.1.1 had a problem related to seaborn, but the result is the same (also when I change the style to ggplot).
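A quick way to see what goes wrong is to print the tick positions that matshow sets up by default and pair them, in order, with the `[''] + labels` list from the question, which is how set_xticklabels() hands out its strings. This is a diagnostic sketch only: the 10x10 matrix below is a placeholder (only its shape matters), not the data above.

```python
import matplotlib.pyplot as plt
import numpy as np

# Placeholder 10x10 matrix: only its shape matters for the tick positions.
cm = np.arange(100).reshape(10, 10)
labels = np.array([5, 6, 7, 8, 9, 10, 11, 12, 14, 15])

fig, ax = plt.subplots()
ax.matshow(cm, interpolation='nearest')

# Tick positions chosen by matshow's default locator; the list can include
# positions just outside the image.
ticks = ax.get_xticks()
lo, hi = sorted(ax.get_xlim())

# set_xticklabels() assigns its strings to these ticks in order, so pair
# them the same way to see where each label from the question would land.
tick_text = [''] + labels.astype(str).tolist()
for pos, text in zip(ticks, tick_text):
    where = 'inside the image' if lo <= pos <= hi else 'outside the image'
    print(f'tick at x={pos:g} gets label {text!r} ({where})')
```

If the output shows ticks only on every second column, that explains the plot: the labels 5 to 9 are drawn, but on columns 0, 2, 4, 6 and 8 rather than on the columns they belong to.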


Solution

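  • By default, matshow() doesn't put a tick at every column: its tick locator picks a limited number of "nice" positions (which also keeps longer labels from overlapping). set_xticklabels() only assigns the given strings, in order, to the ticks that already exist, so only some of the labels show up and they land on the wrong columns. You can force a tick at every column with ax.set_xticks(range(10)), and at every row with ax.set_yticks(range(10)).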

    Here is some example code, with the calls adapted to matplotlib's "object-oriented" interface. Some extra padding keeps the title from colliding with the top tick labels. Note that the labels can be numeric; matplotlib converts them to the corresponding strings. ax.tick_params() can remove the tick marks at the bottom and top (or, alternatively, also at the left and/or right). The sample code also draws a grid on the minor ticks to separate the cells.

    import matplotlib.pyplot as plt
    from matplotlib.ticker import MultipleLocator
    import numpy as np
    
    cm = np.random.randint(0, 25000, (10, 10)) * np.random.randint(0, 2, (10, 10))
    labels = np.array([5, 6, 7, 8, 9, 10, 11, 12, 14, 15])
    
    fig, ax = plt.subplots()
    mat = ax.matshow(cm, interpolation='nearest')
    mat.set_clim(0, 100)
    ax.set_xticks(range(10))
    ax.set_yticks(range(10))
    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)
    ax.tick_params(axis='x', which='both', bottom=False, top=False)
    
    # Pass the on/off flag positionally: the keyword changed from b= to visible= in matplotlib 3.5.
    ax.grid(False, which='major', axis='both')
    ax.xaxis.set_minor_locator(MultipleLocator(0.5))
    ax.yaxis.set_minor_locator(MultipleLocator(0.5))
    ax.grid(True, which='minor', axis='both', lw=2, color='white')
    
    ax.set_title('Confusion matrix of the {} classifier'.format('clf_name'), pad=20)
    plt.colorbar(mat, extend='both')
    plt.show()
    

    [Image: example plot]
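The same fix can be wrapped in a reusable function. This is a minimal sketch, not part of the original answer: `plot_confusion_matrix` is a hypothetical name, it assumes a square matrix whose rows and columns follow `labels` (the order that `confusion_matrix(..., labels=labels)` produces), and it draws the cell separators with explicit minor ticks at the cell edges instead of `MultipleLocator`.

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_confusion_matrix(cm, labels, title="Confusion matrix", vmax=None):
    """Hypothetical helper: draw `cm` with one labelled tick per row/column."""
    cm = np.asarray(cm)
    n = len(labels)
    if cm.shape != (n, n):
        raise ValueError(f"expected a {n}x{n} matrix, got shape {cm.shape}")

    fig, ax = plt.subplots()
    mat = ax.matshow(cm, interpolation="nearest")
    if vmax is not None:
        mat.set_clim(0, vmax)

    # One major tick per column/row, then label each tick explicitly.
    tick_labels = [str(label) for label in labels]
    ax.set_xticks(range(n))
    ax.set_yticks(range(n))
    ax.set_xticklabels(tick_labels)
    ax.set_yticklabels(tick_labels)

    # Cell separators: minor ticks on the cell edges (-0.5, 0.5, ..., n - 0.5).
    ax.set_xticks(np.arange(-0.5, n), minor=True)
    ax.set_yticks(np.arange(-0.5, n), minor=True)
    ax.grid(False, which="major")
    ax.grid(True, which="minor", color="white", linewidth=2)
    ax.tick_params(which="minor", length=0)  # hide the minor tick marks

    ax.set_title(title, pad=20)
    fig.colorbar(mat, ax=ax, extend="both")
    return fig, ax
```

For example, `plot_confusion_matrix(cm, labels, vmax=100)` keeps the question's color limits. On matplotlib 3.5 or newer, `ax.set_xticks(range(n), labels=tick_labels)` can set the positions and the labels in a single call.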