Tags: python, scikit-learn, classification, threshold, roc

How to set a threshold for a sklearn classifier based on ROC results?


I trained an ExtraTreesClassifier (Gini criterion) using scikit-learn, and it suits my needs fairly well. The accuracy is not great, but with 10-fold cross-validation the AUC is 0.95. I would like to use this classifier in my work. I am quite new to ML, so please forgive me if I am asking something conceptually wrong.

I plotted some ROC curves, and from them it seems there is a specific threshold at which my classifier starts performing well. I'd like to set this value on the fitted classifier so that every time I call predict, the classifier uses that threshold and I can trust the FP and TP rates.

I also came across this post (scikit .predict() default threshold), where it is stated that a threshold is not a generic concept for classifiers. But since ExtraTreesClassifier has the predict_proba method, and the ROC curve is itself defined in terms of thresholds, it seems to me I should be able to specify one.

I did not find any parameter, class, or interface for this. How can I set a threshold on a trained ExtraTreesClassifier (or any other classifier) using scikit-learn?

Many Thanks, Colis


Solution

  • This is what I have done:

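    from sklearn.metrics import confusion_matrix, roc_curve

    model = SomeSklearnModel()
    model.fit(X_train, y_train)
    predict = model.predict(X_test)
    predict_probabilities = model.predict_proba(X_test)
    # roc_curve expects 1-D scores, so pass the positive-class column only
    fpr, tpr, thresholds = roc_curve(y_test, predict_probabilities[:, 1])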
    

    However, I am annoyed that predict picks a threshold that gives a true positive rate of only 0.4% (with zero false positives). The ROC curve shows a point I like better for my problem, where the true positive rate is approximately 20% and the false positive rate is around 4%. I then scan predict_probabilities to find the probability value that corresponds to my favourite ROC point. In my case this probability is 0.21. Then I create my own prediction array:

    import numpy as np
    predict_mine = np.where(predict_probabilities[:, 1] > 0.21, 1, 0)  # column 1: positive class
    

    and there you go:

    confusion_matrix(y_test, predict_mine)
    

    returns what I wanted:

    array([[6927,  309],
           [ 621,  121]])
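
To avoid scanning the probabilities by hand, and to get a predict that always applies the chosen cut-off, something like the following might work. It is only a sketch under some assumptions: the problem is binary, column 1 of predict_proba is the positive class, and the three arrays passed in are the ones returned by roc_curve. The names threshold_for_max_fpr and ThresholdedClassifier are made up for this example and are not part of scikit-learn.

```python
import numpy as np


def threshold_for_max_fpr(fpr, tpr, thresholds, max_fpr):
    """Return the threshold with the highest TPR whose FPR is <= max_fpr.

    fpr, tpr and thresholds are assumed to be the three arrays returned
    by sklearn.metrics.roc_curve (same length, same order).
    """
    fpr, tpr, thresholds = (np.asarray(a) for a in (fpr, tpr, thresholds))
    allowed = np.flatnonzero(fpr <= max_fpr)
    if allowed.size == 0:
        raise ValueError("no ROC point satisfies the FPR limit")
    # argmax returns the first maximum, i.e. the allowed point with the
    # best TPR and, among ties, the lowest FPR.
    best = allowed[np.argmax(tpr[allowed])]
    return float(thresholds[best])


class ThresholdedClassifier:
    """Hypothetical wrapper whose predict() applies a fixed threshold."""

    def __init__(self, model, threshold):
        self.model = model          # an already fitted binary classifier
        self.threshold = threshold  # cut-off on the positive-class probability

    def predict_proba(self, X):
        return self.model.predict_proba(X)

    def predict(self, X):
        # Column 1 is the probability of model.classes_[1].
        positive_proba = self.model.predict_proba(X)[:, 1]
        # roc_curve counts score >= threshold as positive, so use >= here too.
        picked = (positive_proba >= self.threshold).astype(int)
        return np.asarray(self.model.classes_)[picked]
```

With the arrays from the snippet above, usage would be roughly: call roc_curve(y_test, model.predict_proba(X_test)[:, 1]) to get fpr, tpr and thresholds, pass them to threshold_for_max_fpr with max_fpr=0.04, then build ThresholdedClassifier(model, threshold) and call its predict. Note that the wrapper uses >= where the array above used >, which only matters for samples that sit exactly on the threshold. Recent scikit-learn versions (1.5 and later) also ship FixedThresholdClassifier in sklearn.model_selection, which may be preferable to a hand-rolled wrapper if your version has it.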