
Error when using pipelines in scikit-learn


I am trying to perform scaling with StandardScaler and then apply a KNeighborsClassifier (i.e. create a pipeline of a scaler and an estimator).

Finally, I want to create a grid search cross-validator for this pipeline, where param_grid is a dictionary with n_neighbors as the hyperparameter and k_vals as its values.

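from sklearn import neighbors
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

def kNearest(k_vals):
    # Recent scikit-learn versions raise a ValueError for random_state unless shuffle=True.
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=23)

    # Scaler followed by the classifier, with step names 'ss' and 'knc'.
    svp = Pipeline([('ss', StandardScaler()),
                    ('knc', neighbors.KNeighborsClassifier())])

    parameters = {'n_neighbors': k_vals}

    clf = GridSearchCV(estimator=svp, param_grid=parameters, cv=skf)

    return clf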

But when I fit the resulting GridSearchCV object, I get this error:

Invalid parameter n_neighbors for estimator Pipeline. Check the list of available parameters with `estimator.get_params().keys()`.
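A minimal reproduction, assuming the bundled iris dataset as stand-in data (the question does not say which data is used). Constructing the GridSearchCV object succeeds; the error only appears once fit runs, because that is when each grid candidate is passed to the pipeline's set_params:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Iris is only stand-in data for this reproduction.
X, y = load_iris(return_X_y=True)

svp = Pipeline([('ss', StandardScaler()),
                ('knc', KNeighborsClassifier())])

# Building the search does not check the grid keys yet...
clf = GridSearchCV(estimator=svp, param_grid={'n_neighbors': [3, 5]}, cv=5)

# ...the ValueError is raised when fit() calls svp.set_params(n_neighbors=...).
try:
    clf.fit(X, y)
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(error_message)
```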

I've read the documentation, but still don't quite get what the error indicates and how to fix it.


Solution

  • You are right, this is not well documented by scikit-learn; the GridSearchCV class docstring makes no reference to it.

    If you use a pipeline as the estimator in a grid search, you need a special syntax when specifying the parameter grid: the step name, followed by a double underscore, followed by the parameter name as you would pass it to the estimator. That is:

    '<named_step>__<parameter>': value
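    The same keys work with Pipeline.set_params, which is what GridSearchCV calls for each candidate, so this is a quick way to test a key outside a grid search. A small sketch using the step names from the question:

```python
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

svp = Pipeline([('ss', StandardScaler()),
                ('knc', KNeighborsClassifier())])

# '<named_step>__<parameter>': 'knc' selects the step, 'n_neighbors' the parameter.
svp.set_params(knc__n_neighbors=3, ss__with_mean=False)

print(svp.named_steps['knc'].n_neighbors)  # 3
print(svp.named_steps['ss'].with_mean)     # False
```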
    

    In your case:

    parameters = {'knc__n_neighbors': k_vals}
    

    should do the trick.
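    Putting it together, a sketch of the corrected function with a fit on the bundled iris dataset (stand-in data, not from the question). It also passes shuffle=True to StratifiedKFold, since recent scikit-learn versions reject random_state otherwise:

```python
from sklearn import neighbors
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler


def kNearest(k_vals):
    # shuffle=True so that random_state is accepted and actually used.
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=23)

    svp = Pipeline([('ss', StandardScaler()),
                    ('knc', neighbors.KNeighborsClassifier())])

    # Step name 'knc' + double underscore + the estimator's parameter name.
    parameters = {'knc__n_neighbors': k_vals}

    return GridSearchCV(estimator=svp, param_grid=parameters, cv=skf)


# Iris is stand-in data; substitute your own X and y.
X, y = load_iris(return_X_y=True)
clf = kNearest([1, 3, 5, 7])
clf.fit(X, y)
print(clf.best_params_)
```

    best_params_ uses the same prefixed keys as the grid, and best_estimator_ is the refitted pipeline.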

    Here knc is the name of a step in your pipeline. The named_steps attribute shows these steps as a dictionary:

    svp.named_steps
    
    {'knc': KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
                metric_params=None, n_jobs=1, n_neighbors=5, p=2,
                weights='uniform'),
     'ss': StandardScaler(copy=True, with_mean=True, with_std=True)}
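    A related pitfall, as an aside: if the pipeline is built with make_pipeline instead of Pipeline, the step names are generated from the lowercased class names, so the grid key prefix changes too. A small sketch (the grid values are illustrative):

```python
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# make_pipeline names each step after its lowercased class name.
pipe = make_pipeline(StandardScaler(), KNeighborsClassifier())

step_names = list(pipe.named_steps.keys())
print(step_names)  # ['standardscaler', 'kneighborsclassifier']

# The grid key therefore needs the generated name as its prefix.
param_grid = {'kneighborsclassifier__n_neighbors': [3, 5, 7]}
print(all(key in pipe.get_params() for key in param_grid))  # True
```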
    

    And as your traceback alludes to:

    svp.get_params().keys()
    dict_keys(['memory', 'steps', 'ss', 'knc', 'ss__copy', 'ss__with_mean', 'ss__with_std', 'knc__algorithm', 'knc__leaf_size', 'knc__metric', 'knc__metric_params', 'knc__n_jobs', 'knc__n_neighbors', 'knc__p', 'knc__weights'])
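    Since every valid grid key appears in get_params(), you can filter those keys by step prefix to see what is searchable for a given step. A sketch (the exact list printed depends on your scikit-learn version):

```python
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

svp = Pipeline([('ss', StandardScaler()),
                ('knc', KNeighborsClassifier())])

params = svp.get_params()

# Keys usable in param_grid for the 'knc' step all start with 'knc__'.
knc_keys = sorted(k for k in params if k.startswith('knc__'))
print(knc_keys)

# The bare name from the question is not a Pipeline parameter.
print('n_neighbors' in params)       # False
print('knc__n_neighbors' in params)  # True
```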
    

    Some official references to this: