Tags: python, machine-learning, data-visualization, confusion-matrix, fast-ai

How to change size of confusion matrix in fastai?


I am drawing a confusion matrix in fastai with the following code:

interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()

But I end up with a very small matrix because I have around 20 categories:

[Image: small confusion matrix]

I found a related question for sklearn, but I don't know how to apply it to fastai (because we don't use pyplot directly).


Solution

  • If you check the code of the function ClassificationInterpretation.plot_confusion_matrix (in the file fastai/interpret.py), this is what you see:

        def plot_confusion_matrix(self, normalize=False, title='Confusion matrix', cmap="Blues", norm_dec=2,
                                  plot_txt=True, **kwargs):
            "Plot the confusion matrix, with `title` and using `cmap`."
            # This function is mainly copied from the sklearn docs
            cm = self.confusion_matrix()
            if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
            fig = plt.figure(**kwargs)
            plt.imshow(cm, interpolation='nearest', cmap=cmap)
            plt.title(title)
            tick_marks = np.arange(len(self.vocab))
            plt.xticks(tick_marks, self.vocab, rotation=90)
            plt.yticks(tick_marks, self.vocab, rotation=0)
    

    The key line is fig = plt.figure(**kwargs): any extra keyword arguments you pass to plot_confusion_matrix are forwarded unchanged to matplotlib's plt.figure, which controls the size and resolution of the figure.
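
To see the forwarding in isolation, here is a minimal sketch that needs only matplotlib and NumPy, not fastai. The function plot_matrix, the labels, and the random counts are hypothetical stand-ins made up for illustration; only the plt.figure(**kwargs) pattern is taken from the fastai source quoted above.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def plot_matrix(cm, labels, title="Confusion matrix", cmap="Blues", **kwargs):
    """Hypothetical stand-in for plot_confusion_matrix: extra kwargs go to plt.figure."""
    fig = plt.figure(**kwargs)  # figsize, dpi, etc. are consumed here
    plt.imshow(cm, interpolation="nearest", cmap=cmap)
    plt.title(title)
    ticks = np.arange(len(labels))
    plt.xticks(ticks, labels, rotation=90)
    plt.yticks(ticks, labels, rotation=0)
    return fig


# Made-up data: 20 classes with random counts, only to have something to draw.
labels = [f"class_{i}" for i in range(20)]
cm = np.random.default_rng(0).integers(0, 50, size=(20, 20))

# figsize and dpi are not parameters of plot_matrix; they pass through **kwargs.
fig = plot_matrix(cm, labels, figsize=(12, 12), dpi=100)
print(fig.get_size_inches(), fig.dpi)
```

The printed size and dpi are the ones passed in, which confirms that the wrapper itself never has to know about those parameters.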

    So you could use either or both of these:

    • dpi=xxx (e.g. dpi=200)
    • figsize=(xxx, yyy)

    See this StackOverflow answer for how the two relate to each other: https://stackoverflow.com/a/47639545/1603480
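
As a rough illustration of that relationship: figsize is given in inches and dpi in dots per inch, so the rendered size in pixels is their product. The helper pixel_size below is a hypothetical name used only for this sketch; the values 6.4 x 4.8 inches and 100 dpi are matplotlib's defaults.

```python
def pixel_size(figsize, dpi):
    """Return (width_px, height_px) for a figure of `figsize` inches at `dpi`."""
    width_in, height_in = figsize
    return round(width_in * dpi), round(height_in * dpi)


print(pixel_size((6.4, 4.8), 100))  # matplotlib's default figsize and dpi
print(pixel_size((12, 12), 100))    # larger canvas, same resolution
print(pixel_size((6.4, 4.8), 200))  # same canvas, doubled resolution
```

Both options produce more pixels, but they behave differently: font sizes are specified in points, so raising dpi scales the text along with everything else, while raising figsize leaves the text the same size and gives the cells and tick labels more room. With around 20 categories, figsize is usually the one to try first.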

    So in your case, you could just do:

    interp.plot_confusion_matrix(figsize=(12, 12))
    

    And the confusion matrix would look like:

    [Image: big confusion matrix]

    FYI: this also applies to other plot functions that accept figsize, for example:

    interp.plot_top_losses(20, nrows=5, figsize=(25, 25))