
Difference in ROC-AUC scores in sklearn RandomForestClassifier vs. auc methods


I am getting different ROC-AUC scores for sklearn's RandomForestClassifier depending on whether I read GridSearchCV's best_score_ or compute the score myself with roc_curve and auc.

The following code got me an ROC-AUC (i.e. gs.best_score_) of 0.878:

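from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

def train_model(mod = None, params = None, features = None,
        outcome = ...outcomes array..., metric = 'roc_auc'):
    # loss_func, score_func, fit_params and iid are omitted: they have been
    # removed from the GridSearchCV constructor in current scikit-learn.
    gs = GridSearchCV(mod, params, scoring=metric, n_jobs=-1, refit=True, cv=10,
        verbose=0, pre_dispatch='2*n_jobs', error_score='raise')
    gs.fit(...feature set df..., outcome)

    # Mean cross-validated score and parameters of the best setting
    print(gs.best_score_)
    print(gs.best_params_)

    return gs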

model = RandomForestClassifier(random_state=2000, n_jobs=-1)
features_to_include = [...list of column names...]

parameters = {
            'n_estimators': [...list...], 'max_depth':[...list...],
            'min_samples_split':[...list...], 'min_samples_leaf':[...list...]
            }

gs = train_model(mod = model, params = parameters, features = features_to_include)

By contrast, the following code got me an ROC-AUC of 0.97:

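from sklearn.metrics import roc_curve, auc

fpr = dict()
tpr = dict()
roc_auc = dict()
# ROC curve from the positive-class probabilities of the refit model
fpr['micro'], tpr['micro'], _ = roc_curve(...outcomes array..., 
                                    gs.predict_proba(...feature set df...)[:, 1])
roc_auc['micro'] = auc(fpr['micro'], tpr['micro'])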

Why is there such a difference? Did I do something wrong with my code?

Thanks! Chris


Solution

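  • The two numbers are expected to differ, for two reasons:

    1. GridSearchCV splits your data into 10 folds (cv=10 in your code). For each parameter setting it trains on 9 folds, computes the AUC on the held-out fold, and repeats this until every fold has been held out once. The best_score_ you get is the mean of those held-out AUCs for the best parameter setting (see the GridSearchCV documentation). Your roc_curve calculation instead scores the refit model on the entire set, which is the same data it was trained on, so that AUC is optimistic.

    2. The 'roc_auc' scorer uses roc_auc_score with its default macro averaging, whereas your later computation stores its result under the key 'micro'. For a binary outcome scored on the positive-class probabilities, as here, both come down to a single ROC curve, so this is unlikely to explain much of the gap; the first reason is the main one.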
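
A minimal sketch of the first point. It uses synthetic data from make_classification as a stand-in for your feature set; the data, the small parameter grid, and the variable names are assumptions for illustration, not your actual setup.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import auc, roc_auc_score, roc_curve
from sklearn.model_selection import GridSearchCV, cross_val_predict

# Synthetic binary data (assumption: stands in for the real features/outcome).
X, y = make_classification(n_samples=400, n_features=20, n_informative=5,
                           flip_y=0.2, random_state=0)

gs = GridSearchCV(
    RandomForestClassifier(n_estimators=50, random_state=2000),
    param_grid={"max_depth": [3, None]},  # hypothetical grid
    scoring="roc_auc",
    cv=10,
)
gs.fit(X, y)

# Mean held-out AUC over the 10 folds for the best parameter setting.
cv_auc = gs.best_score_

# AUC of the refit model on the very rows it was trained on (optimistic).
fpr, tpr, _ = roc_curve(y, gs.predict_proba(X)[:, 1])
in_sample_auc = auc(fpr, tpr)

# Out-of-fold probabilities give one curve built only from held-out rows.
oof_proba = cross_val_predict(gs.best_estimator_, X, y, cv=10,
                              method="predict_proba")[:, 1]
oof_auc = roc_auc_score(y, oof_proba)

print("cross-validated AUC (best_score_):", cv_auc)
print("in-sample AUC (roc_curve + auc):  ", in_sample_auc)
print("out-of-fold AUC:                  ", oof_auc)
```

If you want a single ROC curve that is comparable to best_score_, out-of-fold probabilities as in the last step are one option; evaluating on a separate held-out test set is another.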