python scikit-learn text-classification roc multiclass-classification

Can I plot ROC curve for multiclass text classification problem, without using OneVsRestClassifier?


I have a pickle file that, when loaded, returns a trained RandomForest classifier. I want to plot the ROC curve for each class, but from what I have read online, the classifier must be wrapped in scikit-learn's OneVsRestClassifier. The problem is that I already have the trained model, so I cannot wrap it in OneVsRestClassifier without fitting it again.

So I would like to know whether there is a workaround for plotting the ROC curve. From my trained model I have y_test and y_proba, and I also have the x_test values.
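For reference, a minimal sketch of why refitting is not needed: `roc_curve` only takes true labels and scores, so the per-class probabilities from an already-fitted model are enough. The synthetic dataset, its `make_classification` settings, and the pickle round trip (standing in for the real pickle file) are assumptions for illustration, not taken from the question.

```python
import pickle

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize

# Synthetic 5-class data standing in for the real (unavailable) dataset
X, y = make_classification(n_samples=1000, n_features=20, n_informative=10,
                           n_classes=5, random_state=0)
X_train, x_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Simulate the pickle file: fit once, serialize, then restore
blob = pickle.dumps(RandomForestClassifier(random_state=0).fit(X_train, y_train))
clf = pickle.loads(blob)  # a plain RandomForestClassifier, not wrapped in OneVsRestClassifier

# predict_proba returns one score column per class, in the order of clf.classes_
y_proba = clf.predict_proba(x_test)
y_test_bin = label_binarize(y_test, classes=clf.classes_)

curves = {}
for i, cls in enumerate(clf.classes_):
    # One-vs-rest: class `cls` is the positive class, all others are negative
    fpr, tpr, _ = roc_curve(y_test_bin[:, i], y_proba[:, i])
    curves[cls] = (fpr, tpr)

# Macro-averaged one-vs-rest AUC computed from the same probabilities
macro_auc = roc_auc_score(y_test, y_proba, multi_class="ovr")
print(f"macro OvR AUC: {macro_auc:.3f}")
```

The `(fpr, tpr)` pairs in `curves` can be passed straight to `plt.plot`.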

  • The shape of my y_proba is (6715, 5)


  • The shape of y_test is (6715, 5)

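A y_test of shape (6715, 5) may mean the labels are already one-hot encoded; the question does not say, so treat this as an assumption. If it holds, each column of y_test is already a one-vs-rest target and `label_binarize` is unnecessary, while `argmax` recovers a single label column when one is needed. The sketch below uses random stand-in arrays with the same layout; all values are hypothetical.

```python
import numpy as np
from sklearn.metrics import auc, roc_auc_score, roc_curve

# Hypothetical stand-ins with the question's layout:
# y_test is a one-hot indicator matrix, y_proba has one probability column per class
rng = np.random.default_rng(0)
n_samples, n_classes = 200, 5
labels = rng.permutation(np.arange(n_samples) % n_classes)  # every class present
y_test = np.eye(n_classes, dtype=int)[labels]               # shape (200, 5)
logits = rng.normal(size=(n_samples, n_classes)) + 2 * y_test  # informative scores
y_proba = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # softmax rows

# Column i of y_test is already the one-vs-rest target for class i
curves, aucs = [], []
for i in range(n_classes):
    fpr, tpr, _ = roc_curve(y_test[:, i], y_proba[:, i])
    curves.append((fpr, tpr))
    aucs.append(auc(fpr, tpr))

# argmax turns the one-hot matrix back into class ids, e.g. for roc_auc_score
y_labels = y_test.argmax(axis=1)
ovr_auc = roc_auc_score(y_labels, y_proba, multi_class="ovr")
print([round(a, 3) for a in aucs], round(ovr_auc, 3))
```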

This is the output of the code @dx2-66 suggested:

[screenshots of the output not preserved]


Solution

  • I assume your y_test is a single column of class ids and your y_proba has as many columns as there are classes (at least, that is what you would usually get from predict_proba()).

    How about this? It should give you one-vs-rest (OvR) style curves:

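    from sklearn.metrics import roc_curve
    from sklearn.preprocessing import label_binarize
    import matplotlib.pyplot as plt
    
    classes = list(range(y_proba.shape[1]))
    # One-hot encode the true labels once: column i is 1 where the sample is class i
    y_test_bin = label_binarize(y_test, classes=classes)
    
    for i in classes:
        # Class i vs. all other classes
        fpr, tpr, _ = roc_curve(y_test_bin[:, i], y_proba[:, i])
        plt.plot(fpr, tpr, alpha=0.7, label=f"class {i}")
    
    plt.xlabel("False positive rate")
    plt.ylabel("True positive rate")
    plt.legend()
    plt.show()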
    

    Update: solution for class labels that are not simply 0 .. n-1 (e.g. string labels):

    # Sorted, to match the predict_proba column order (the model's classes_ are sorted)
    classes = sorted(y_test['label'].unique())
    y_test_bin = label_binarize(y_test['label'], classes=classes)
    
    plt.plot([0, 1], [0, 1], linestyle='--', label='baseline')  # chance diagonal
    
    for i, cls in enumerate(classes):
        fpr, tpr, _ = roc_curve(y_test_bin[:, i], y_proba.values[:, i])
        plt.plot(fpr, tpr, alpha=0.7, label=str(cls))
    
    plt.legend()
    plt.show()
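A possible extension of the update, sketched under assumptions: the tiny string-labelled dataset below is made up, and y_proba's columns are assumed to follow the sorted label order, as predict_proba's do. It computes a per-class AUC with `sklearn.metrics.auc`, which could go into legend labels such as `f"{cls} (AUC = {a:.2f})"`, plus a micro-averaged curve that pools every (sample, class) decision into one binary problem. That pooling is what scikit-learn's multiclass ROC example calls micro-averaging.

```python
import numpy as np
from sklearn.metrics import auc, roc_curve
from sklearn.preprocessing import label_binarize

# Hypothetical example data; y_proba columns are assumed to be in sorted label order
classes = ["cat", "dog", "fish"]
y_true = np.array(["dog", "cat", "fish", "dog", "cat", "fish", "dog", "cat"])
y_proba = np.array([
    [0.2, 0.7, 0.1],
    [0.6, 0.3, 0.1],
    [0.1, 0.2, 0.7],
    [0.3, 0.5, 0.2],
    [0.5, 0.4, 0.1],
    [0.2, 0.3, 0.5],
    [0.4, 0.4, 0.2],
    [0.7, 0.2, 0.1],
])
y_bin = label_binarize(y_true, classes=classes)  # shape (8, 3)

# Per-class one-vs-rest AUC, e.g. for the legend labels
per_class_auc = {}
for i, cls in enumerate(classes):
    fpr, tpr, _ = roc_curve(y_bin[:, i], y_proba[:, i])
    per_class_auc[cls] = auc(fpr, tpr)

# Micro-average: flatten labels and scores into a single binary problem
fpr_micro, tpr_micro, _ = roc_curve(y_bin.ravel(), y_proba.ravel())
micro_auc = auc(fpr_micro, tpr_micro)
print(per_class_auc, micro_auc)
```

`fpr_micro` and `tpr_micro` can be plotted alongside the per-class curves.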