python scikit-learn confusion-matrix

Accuracy and Confusion Matrix in Cross Validation


I am training a model to solve a binary classification problem using scikit-learn, and I wish to perform cross-validation with 5 folds.

As metrics, I would like to get both the average accuracy and a confusion matrix over the 5 folds.

Using cross_validate, I can pass multiple metrics to the scoring parameter.

According to this link, I can define a function that returns the confusion matrix at each fold. That code predicts from X via .predict(X). But shouldn't a test set, x_test, have been used instead? Since cross_validate produces a different test set at each fold, I don't understand how we can simply pass X to both confusion_matrix_scorer() and .predict(). Another question: clf is the svm here, right?


Solution

  • The docs state that a callable scorer should satisfy the following:

    It can be called with parameters (estimator, X, y), where estimator is the model that should be evaluated, X is validation data, and y is the ground truth target for X (in the supervised case) or None (in the unsupervised case).

    When cross_validate is called, the CV folds are generated first and passed to independent fitting processes. Inside each process, that fold's test split is passed to a private _score function. From the source code:

    test_scores = _score(estimator, X_test, y_test, scorer, error_score)
    

    which calls the input scorer with the defined parameters (estimator, X, y) (source code):

    scores = scorer(estimator, X_test, y_test)
    

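    So the X your scorer receives is the current fold's test split, not the full dataset, and clf is the estimator (your SVM) fitted on that fold's training split. If you want both the average accuracy and a confusion matrix, you can return these scores from the scorer as a dictionary.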

    Example code

    from sklearn.metrics import accuracy_score, confusion_matrix

    def confusion_matrix_scorer(clf, X, y):
        # clf is the fitted estimator; X and y are the current fold's test split
        y_pred = clf.predict(X)
        cm = confusion_matrix(y, y_pred)
        acc = accuracy_score(y, y_pred)
        return {
            'acc': acc,
            'tn': cm[0, 0],
            'fp': cm[0, 1],
            'fn': cm[1, 0],
            'tp': cm[1, 1],
        }
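
A minimal sketch of how such a scorer could be used end to end. The toy data from make_classification and the linear SVC are chosen here only for illustration and are not taken from the question. The sketch averages the accuracy over the 5 folds, sums the per-fold counts into one confusion matrix, and then repeats the loop by hand with StratifiedKFold to show that the X seen by the scorer is each fold's test split. The labels=[0, 1] argument is an added assumption about the class labels; it keeps the matrix 2x2 even if a fold happens to miss one class.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import StratifiedKFold, cross_validate
from sklearn.svm import SVC


def confusion_matrix_scorer(clf, X, y):
    # clf arrives already fitted on the fold's training split;
    # X and y are that fold's held-out test split.
    y_pred = clf.predict(X)
    # Assumption: the two classes are labelled 0 and 1.
    cm = confusion_matrix(y, y_pred, labels=[0, 1])
    return {
        "acc": accuracy_score(y, y_pred),
        "tn": cm[0, 0],
        "fp": cm[0, 1],
        "fn": cm[1, 0],
        "tp": cm[1, 1],
    }


# Toy data and estimator, for illustration only.
X, y = make_classification(n_samples=100, n_classes=2, random_state=0)
svm = SVC(kernel="linear")

# Each key returned by the scorer shows up as "test_<key>", one value per fold.
cv_results = cross_validate(svm, X, y, cv=5, scoring=confusion_matrix_scorer)

mean_acc = float(np.mean(cv_results["test_acc"]))
total_cm = np.array([
    [np.sum(cv_results["test_tn"]), np.sum(cv_results["test_fp"])],
    [np.sum(cv_results["test_fn"]), np.sum(cv_results["test_tp"])],
])

# The same loop written by hand: fit on the training indices,
# then call the scorer on the test indices only.
manual_cm = np.zeros((2, 2), dtype=int)
for train_idx, test_idx in StratifiedKFold(n_splits=5).split(X, y):
    fold_clf = clone(svm).fit(X[train_idx], y[train_idx])
    fold = confusion_matrix_scorer(fold_clf, X[test_idx], y[test_idx])
    manual_cm += np.array([[fold["tn"], fold["fp"]],
                           [fold["fn"], fold["tp"]]])

print("mean accuracy:", mean_acc)
print("summed confusion matrix:")
print(total_cm)
```

Summing the counts gives one confusion matrix over all held-out predictions, since every sample lands in exactly one test fold. Averaging the four counts instead would give a per-fold mean matrix, which may be closer to what "average" means in some reports.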