Tags: python, scikit-learn, confusion-matrix, mlflow

How to mlflow-autolog a sklearn ConfusionMatrixDisplay?


I'm trying to log the plot of a confusion matrix generated with scikit-learn for a test set using mlflow's support for scikit-learn.

For this, I tried something that resembles the code below (I'm using mlflow hosted on Databricks, and sklearn==1.0.1):

import sklearn.datasets
import pandas as pd
import numpy as np
import mlflow
from sklearn.pipeline import Pipeline
from sklearn.linear_model import SGDClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split

mlflow.set_tracking_uri("databricks")
mlflow.set_experiment("/Users/name.surname/plotcm")

data = sklearn.datasets.fetch_20newsgroups(categories=['alt.atheism', 'sci.space'])

df = pd.DataFrame(data = np.c_[data['data'], data['target']])\
       .rename({0:'text', 1:'class'}, axis = 'columns')

train, test = train_test_split(df)

my_pipeline = Pipeline([
    ('vectorizer', TfidfVectorizer()),
    ('classifier', SGDClassifier(loss='modified_huber')),
])

mlflow.sklearn.autolog()

from sklearn.metrics import ConfusionMatrixDisplay # should I import this after the call to `.autolog()`?

my_pipeline.fit(train['text'].values, train['class'].values)

cm = ConfusionMatrixDisplay.from_predictions(
      y_true=test["class"], y_pred=my_pipeline.predict(test["text"])
  )

While the confusion matrix for the training set is saved in my mlflow run, no PNG file is created in the mlflow frontend for the test set.

If I try to add

cm.figure_.savefig('test_confusion_matrix.png')
mlflow.log_artifact('test_confusion_matrix.png')

that does the job, but requires explicitly logging the artifact.

Is there an idiomatic/proper way to autolog the confusion matrix computed using a test set after my_pipeline.fit()?


Solution

  • The proper way to do this is to use mlflow.log_figure, part of the fluent API since MLflow 1.13.0 (see the MLflow documentation for mlflow.log_figure). This code will do the job:

    mlflow.log_figure(cm.figure_, 'test_confusion_matrix.png')
    

    This function saves the figure to a temporary file and then calls log_artifact on that path, much like what you did by hand.