I am learning scikit-learn and am trying to perform a grid search on a multi-label classification problem. This is what I wrote:
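from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.multiclass import OneVsRestClassifier
from sklearn.pipeline import make_pipeline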
param_grid = [
{'randomforestclassifier__n_estimators': [3, 10, 30], 'randomforestclassifier__max_features': [2, 4, 5, 8]},
{'randomforestclassifier__bootstrap': [False], 'randomforestclassifier__n_estimators': [3, 10], 'randomforestclassifier__max_features': [2, 3, 4]}
]
rf_classifier = OneVsRestClassifier(
make_pipeline(RandomForestClassifier(random_state=42))
)
grid_search = GridSearchCV(rf_classifier, param_grid=param_grid, cv=5, scoring='f1_micro')
grid_search.fit(X_train_prepared, y_train)
However, when I run it, I get the following error:
ValueError: Invalid parameter randomforestclassifier for estimator
OneVsRestClassifier(estimator=Pipeline(steps=[('randomforestclassifier',
RandomForestClassifier(random_state=42))])). Check the list of available parameters
with `estimator.get_params().keys()`.
I also tried running grid_search.estimator.get_params().keys(), but I just get a list of parameters containing the ones that I have written, so I am not sure what I should do.
Would you be able to suggest what the issue is and how I can run the grid search properly?
You would have been able to define param_grid as you did if rf_classifier were a Pipeline object. Quoting the Pipeline docs:
The purpose of the pipeline is to assemble several steps that can be cross-validated together while setting different parameters. For this, it enables setting parameters of the various steps using their names and the parameter name separated by a '__'.
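To illustrate the quoted convention, here is a minimal sketch on a bare pipeline (not yet wrapped in OneVsRestClassifier). The variable names pipe, step_names, pipe_param_names and n_estimators are only for illustration and do not come from the question.

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import make_pipeline

# make_pipeline names each step after its lowercased class name.
pipe = make_pipeline(RandomForestClassifier(random_state=42))
step_names = [name for name, _ in pipe.steps]

# Parameters of a step are exposed as <step name>__<parameter name>.
pipe_param_names = sorted(pipe.get_params().keys())

# set_params accepts the same double-underscore names.
pipe.set_params(randomforestclassifier__n_estimators=10)
n_estimators = pipe.named_steps["randomforestclassifier"].n_estimators
```

On a plain Pipeline, keys such as randomforestclassifier__n_estimators are therefore valid, which is why the original param_grid would have worked there.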
In your case, however, rf_classifier is a OneVsRestClassifier instance. Therefore, before setting the parameters of the RandomForestClassifier, you need to reach the pipeline instance, which you can do via the estimator parameter of OneVsRestClassifier, as follows:
param_grid = [
{'estimator__randomforestclassifier__n_estimators': [3, 10, 30],
'estimator__randomforestclassifier__max_features': [2, 4, 5, 8]
},
{'estimator__randomforestclassifier__bootstrap': [False],
'estimator__randomforestclassifier__n_estimators': [3, 10],
'estimator__randomforestclassifier__max_features': [2, 3, 4]
}
]
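For completeness, here is a minimal end-to-end sketch of the search with the prefixed names. It assumes synthetic data from make_multilabel_classification as a stand-in for X_train_prepared and y_train (which are not shown in the question), and it uses a reduced grid and cv=3 only to keep the run short.

```python
from sklearn.datasets import make_multilabel_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.multiclass import OneVsRestClassifier
from sklearn.pipeline import make_pipeline

# Assumption: synthetic multi-label data standing in for the question's
# X_train_prepared / y_train.
X, y = make_multilabel_classification(
    n_samples=200, n_features=10, n_classes=3, random_state=42
)

rf_classifier = OneVsRestClassifier(
    make_pipeline(RandomForestClassifier(random_state=42))
)

# Reduced grid for illustration; every key starts with "estimator__" so that
# it is routed through OneVsRestClassifier to the inner pipeline step.
param_grid = [
    {
        "estimator__randomforestclassifier__n_estimators": [3, 10],
        "estimator__randomforestclassifier__max_features": [2, 4],
    },
    {
        "estimator__randomforestclassifier__bootstrap": [False],
        "estimator__randomforestclassifier__n_estimators": [3],
        "estimator__randomforestclassifier__max_features": [2, 3],
    },
]

# The keys used in the grid must appear among the wrapper's parameters.
available_params = set(rf_classifier.get_params().keys())

grid_search = GridSearchCV(
    rf_classifier, param_grid=param_grid, cv=3, scoring="f1_micro"
)
grid_search.fit(X, y)

best_params = grid_search.best_params_
best_score = grid_search.best_score_
```

Checking the grid keys against rf_classifier.get_params().keys() before fitting is a quick way to catch a misspelled or unprefixed name.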