Tags: python, theory, cross-validation, grid-search

Theory behind grid search with cross validation


Thanks to the help of Stack Overflow, I successfully implemented grid search with cross-validation for my decision tree model.

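from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.tree import DecisionTreeClassifier

dtc = DecisionTreeClassifier()

# Hyperparameter grid to search exhaustively
parameter_grid = {'splitter': ['best', 'random'],
                  'min_samples_split': [15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25],
                  'min_samples_leaf': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                  'criterion': ['gini', 'entropy'],
                  'random_state': [0]}

# 10-fold stratified cross-validation
cross_validation = StratifiedKFold(n_splits=10)

grid_search = GridSearchCV(dtc, param_grid=parameter_grid, cv=cross_validation)

# x is the feature matrix and y the class labels (defined elsewhere)
grid_search.fit(x, y)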

My question concerns the theory behind it.

I know that k-fold cross validation splits my entire data set into k training data sets and corresponding validation data sets.
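
To illustrate that splitting, here is a minimal sketch using StratifiedKFold on toy data. The arrays X_toy and y_toy and the choice of 5 folds are assumptions made only for this illustration; they are not the data from the question.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Toy data (assumed for illustration): 20 samples, 2 balanced classes.
X_toy = np.arange(40).reshape(20, 2)
y_toy = np.array([0, 1] * 10)

skf = StratifiedKFold(n_splits=5)

# Each element is a (training indices, validation indices) pair.
folds = list(skf.split(X_toy, y_toy))

for i, (train_idx, val_idx) in enumerate(folds):
    print(f"fold {i}: {len(train_idx)} training rows, "
          f"{len(val_idx)} validation rows")
```

Every sample ends up in exactly one validation set, and the matching training set consists of the remaining k-1 folds.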

I assume that my code then does something like this:

  1. Apply the grid search k times, once on each of the k training data sets.

  2. Apply the best parameters found by the grid search on each fold to the corresponding validation data set.

  3. Calculate the validation error for each of the k validation data sets.

Is this correct so far?

Which values do I obtain from grid_search.best_score_ and grid_search.best_params_? Are they the best validation error from step 3 (grid_search.best_score_) and the corresponding grid search parameters (grid_search.best_params_), or some average value?

Any help or clarifications are highly welcome!


Solution

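  • It works the other way around: the cross-validation runs inside the grid search. For each combination in the grid (in your case 2*11*10*2*1 = 440), the data set is split into k folds, the model is trained k times (each time on k-1 folds) and scored on the held-out fold, and the k validation scores are averaged. The combination with the best average score (equivalently, the lowest average error) is grid_search.best_params_, and that average score is grid_search.best_score_.

    For example: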

    {'splitter': 'best',
     'min_samples_split': 20,
     'min_samples_leaf': 9,
     'criterion': 'entropy',
     'random_state': 0}
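
The procedure described above can be sketched by hand and compared with GridSearchCV. This is only an illustration under stated assumptions: the iris data set stands in for the asker's x and y, small_grid is a reduced hypothetical grid chosen to keep the run short, and 5 folds are used instead of 10. The default scorer for a classifier is accuracy, so "best" here means the highest mean accuracy.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import (GridSearchCV, ParameterGrid,
                                     StratifiedKFold, cross_val_score)
from sklearn.tree import DecisionTreeClassifier

# The grid from the question, used only to count its combinations.
question_grid = {'splitter': ['best', 'random'],
                 'min_samples_split': list(range(15, 26)),
                 'min_samples_leaf': list(range(1, 11)),
                 'criterion': ['gini', 'entropy'],
                 'random_state': [0]}
n_combinations = len(ParameterGrid(question_grid))  # 2*11*10*2*1

# Stand-in data and a reduced grid (both assumptions for this sketch).
X_demo, y_demo = load_iris(return_X_y=True)
small_grid = {'min_samples_split': [2, 10, 20],
              'criterion': ['gini', 'entropy'],
              'random_state': [0]}
cv = StratifiedKFold(n_splits=5)

# Manual version: one mean validation score per parameter combination.
manual_scores = []
for params in ParameterGrid(small_grid):
    fold_scores = cross_val_score(DecisionTreeClassifier(**params),
                                  X_demo, y_demo, cv=cv)
    manual_scores.append((fold_scores.mean(), params))

manual_best_score = max(score for score, _ in manual_scores)

# GridSearchCV performs the same loop internally.
search = GridSearchCV(DecisionTreeClassifier(), param_grid=small_grid, cv=cv)
search.fit(X_demo, y_demo)

print(n_combinations)
print(search.best_params_, search.best_score_, manual_best_score)
```

The per-combination averages are also available in search.cv_results_['mean_test_score'], and best_score_ is the largest of them. With the default refit=True, the best combination is afterwards refit on the whole data set and exposed as best_estimator_.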