Tags: python, machine-learning, scikit-learn, classification, cross-validation

How to get multi-class roc_auc in cross validate in sklearn?


I have a classification problem where I want to get the roc_auc value using cross_validate in sklearn. My code is as follows.

from sklearn import datasets
iris = datasets.load_iris()
X = iris.data[:, :2]  # we only take the first two features.
y = iris.target

from sklearn.ensemble import RandomForestClassifier
clf=RandomForestClassifier(random_state = 0, class_weight="balanced")

from sklearn.model_selection import cross_validate
cross_validate(clf, X, y, cv=10, scoring = ('accuracy', 'roc_auc'))

However, I get the following error.

ValueError: multiclass format is not supported

Please note that I selected roc_auc specifically because it supports both binary and multiclass classification, as mentioned in: https://scikit-learn.org/stable/modules/model_evaluation.html

I have a binary classification dataset too. Please let me know how to resolve this error.

I am happy to provide more details if needed.


Solution

  • By default, roc_auc_score uses multi_class='raise', which throws an error for multiclass targets, so you need to change it explicitly.

    From the docs:

    multi_class {‘raise’, ‘ovr’, ‘ovo’}, default=’raise’

    Multiclass only. Determines the type of configuration to use. The default value raises an error, so either 'ovr' or 'ovo' must be passed explicitly.

    'ovr':

    Computes the AUC of each class against the rest [3] [4]. This treats the multiclass case in the same way as the multilabel case. Sensitive to class imbalance even when average == 'macro', because class imbalance affects the composition of each of the ‘rest’ groupings.

    'ovo':

    Computes the average AUC of all possible pairwise combinations of classes [5]. Insensitive to class imbalance when average == 'macro'.
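
To see the two modes side by side, here is a minimal sketch that calls roc_auc_score directly on held-out data. The train/test split and its parameters (test_size, stratify) are my own choices for illustration and are not part of the original question; the key point is that multiclass AUC is computed from predict_proba output, not from predicted labels.

```python
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

iris = datasets.load_iris()
X = iris.data[:, :2]  # same two features as in the question
y = iris.target

# Illustrative split (an assumption, not from the original post).
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=0
)

clf = RandomForestClassifier(random_state=0, class_weight="balanced")
clf.fit(X_train, y_train)

# Multiclass AUC needs per-class probabilities of shape (n_samples, n_classes).
y_proba = clf.predict_proba(X_test)

# multi_class must be set explicitly; the default 'raise' gives a ValueError.
auc_ovr = roc_auc_score(y_test, y_proba, multi_class="ovr")
auc_ovo = roc_auc_score(y_test, y_proba, multi_class="ovo")

print("one-vs-rest AUC:", auc_ovr)
print("one-vs-one AUC:", auc_ovo)
```

Both calls use the default average='macro', so the two numbers differ only in how the per-class or per-pair AUCs are formed before averaging.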


    Solution:

    Use make_scorer (docs):

    from sklearn import datasets
    iris = datasets.load_iris()
    X = iris.data[:, :2]  # we only take the first two features.
    y = iris.target
    
    from sklearn.ensemble import RandomForestClassifier
    clf=RandomForestClassifier(random_state = 0, class_weight="balanced")
    
    from sklearn.metrics import make_scorer
    from sklearn.metrics import roc_auc_score
    
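    # needs_proba=True makes the scorer pass predict_proba output to roc_auc_score.
    # Note: recent scikit-learn releases (1.4 and later) deprecate this flag in
    # favour of response_method='predict_proba'.
    myscore = make_scorer(roc_auc_score, multi_class='ovo', needs_proba=True)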
    
    from sklearn.model_selection import cross_validate
    cross_validate(clf, X, y, cv=10, scoring = myscore)
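
As an alternative to building a custom scorer, scikit-learn (0.22 and later) also ships predefined scorer names for the multiclass case: 'roc_auc_ovr', 'roc_auc_ovo', and their '_weighted' variants. The sketch below, which assumes such a version is installed, keeps the multi-metric form from the question and only swaps 'roc_auc' for 'roc_auc_ovo'.

```python
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_validate

iris = datasets.load_iris()
X = iris.data[:, :2]  # same two features as in the question
y = iris.target

clf = RandomForestClassifier(random_state=0, class_weight="balanced")

# 'roc_auc_ovo' is a built-in scorer name; no make_scorer call is needed.
results = cross_validate(clf, X, y, cv=10, scoring=("accuracy", "roc_auc_ovo"))

# With multiple metrics, results are keyed as 'test_<scorer name>'.
print("mean accuracy:", results["test_accuracy"].mean())
print("mean one-vs-one AUC:", results["test_roc_auc_ovo"].mean())
```

For the binary dataset mentioned in the question, the plain 'roc_auc' string should work as-is, since the error only arises when the target has more than two classes.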