python · machine-learning · scikit-learn · cross-validation · grid-search

Why does best_params_ in GridSearchCV ignore the variance?


The documentation of best_params_ in GridSearchCV states:

best_params_ : dict

Parameter setting that gave the best results on the hold out data.

From that, I assumed "best results" meant both the best score (highest accuracy / lowest error) and the lowest variance over my k folds.

However, this is not the case, as we can see in cv_results_:

[Image: cv_results_ table in which four hyperparameter values share the best rank]

Here best_params_ returns k=5 instead of k=9, even though for k=9 both mean_test_score and the variance would be optimal.

I know I can implement my own scoring function, or my own best-parameter selection on top of the output of cv_results_, as sketched below. But what is the rationale for not taking the variance into account in the first place?
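
For concreteness, here is a minimal sketch of such a selector (the function name select_best_index is mine, not scikit-learn's): take the highest mean_test_score and break ties by the lowest std_test_score.

import numpy as np

def select_best_index(cv_results):
    means = np.asarray(cv_results["mean_test_score"])
    stds = np.asarray(cv_results["std_test_score"])
    # Candidates whose mean score ties with the best (up to float tolerance).
    candidates = np.flatnonzero(np.isclose(means, means.max()))
    # Among the tied candidates, prefer the smallest fold-to-fold std.
    return int(candidates[np.argmin(stds[candidates])])

Passing grid_search.cv_results_ to this function would pick the tied candidate with the lowest variance rather than the first one in the grid.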


I ran into this situation while applying KNN to the iris dataset with a 70% train split and 3-fold cross-validation.


Edit: Example code:

import numpy as np
import pandas as pd
from sklearn import neighbors
from sklearn import model_selection
from sklearn import datasets

# Load iris and hold out 30% of it as a test set.
X = datasets.load_iris().data
y = datasets.load_iris().target

X_train, X_test, y_train, y_test = model_selection.train_test_split(
    X, y, train_size=0.7, test_size=0.3, random_state=62)

knn_model = neighbors.KNeighborsClassifier()

# Grid-search odd values of k from 1 to 29 with 3-fold cross-validation.
param_grid = [{"n_neighbors": np.arange(1, 31, 2)}]
grid_search = model_selection.GridSearchCV(knn_model, param_grid, cv=3,
                                           return_train_score=False)
grid_search.fit(X_train, y_train)

results = pd.DataFrame(grid_search.cv_results_)

k_opt = grid_search.best_params_.get("n_neighbors")
print("Value returned by best_params_:", k_opt)
results.head(6)

It produces a different table than the image above, but the situation is the same: for k=5 both mean_test_score and std_test_score are optimal, yet best_params_ returns k=1.
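
You can expose the tie directly by sorting the relevant cv_results_ columns yourself (column names as produced by scikit-learn):

# Sort by mean score (descending), then std (ascending), to expose ties.
cols = ["param_n_neighbors", "mean_test_score", "std_test_score", "rank_test_score"]
print(results[cols]
      .sort_values(["mean_test_score", "std_test_score"], ascending=[False, True])
      .head(6))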


Solution

  • From the GridSearchCV source

        # Find the best parameters by comparing on the mean validation score:
        # note that `sorted` is deterministic in the way it breaks ties
        best = sorted(grid_scores, key=lambda x: x.mean_validation_score,
                      reverse=True)[0]
    

    It sorts by the mean validation score and nothing else. Since sorted() is stable, tied candidates keep their grid order, so the first one in the grid, k=1, wins. (Newer scikit-learn versions select via rank_test_score and argmin, which likewise picks the first of the tied candidates.)

    I agree with your reasoning and think a PR could be submitted to add better tie-breaking logic.
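
    In the meantime, recent scikit-learn versions (0.20+) let you do this yourself: refit accepts a callable that is given cv_results_ and must return the index of the candidate to refit, so you can plug in variance-aware tie-breaking without touching the library. A sketch continuing the example above (the function name refit_lowest_std is mine):

        import numpy as np

        def refit_lowest_std(cv_results):
            # Among the candidates tied on the best mean validation score,
            # pick the one with the lowest fold-to-fold standard deviation.
            means = np.asarray(cv_results["mean_test_score"])
            stds = np.asarray(cv_results["std_test_score"])
            candidates = np.flatnonzero(np.isclose(means, means.max()))
            return int(candidates[np.argmin(stds[candidates])])

        grid_search = model_selection.GridSearchCV(knn_model, param_grid, cv=3,
                                                   refit=refit_lowest_std)
        grid_search.fit(X_train, y_train)
        print(grid_search.best_params_)  # now reflects the variance-aware choice

    Note that with a callable refit, best_score_ is no longer available, but best_params_ and best_index_ are.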