
Scikit GridSearchCV - How do fit() and predict() work in conjunction with ColumnTransformers and Pipelines?


I am a bit confused about how GridSearchCV actually works, so let's imagine an arbitrary regression problem in which I want to predict the price of a house:

Let's say we use a simple preprocessor for target encoding. To prevent data leakage, the target encoder should call fit_transform() on X_train and only transform() on X_test.

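from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import TargetEncoder  # available in scikit-learn >= 1.3

preprocessor = ColumnTransformer(
    transformers=[
        # Replace each zip code with a target (price) statistic learned from the training data
        ("encoded_target_price", TargetEncoder(), ["Zipcodes"]),
    ],
    remainder="passthrough",  # all other columns are passed through unchanged
    n_jobs=-1,
)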
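To make the fit/transform split concrete, here is a minimal stand-in for a target encoder, written in plain Python so it runs without scikit-learn. `ToyTargetEncoder` and the sample zip codes and prices are hypothetical. Real encoders such as scikit-learn's `TargetEncoder` also apply smoothing and use cross-fitting inside `fit_transform()`, so their numbers differ; the sketch only shows that the per-category means come from the training data and that `transform()` on the test set reuses them.

```python
from collections import defaultdict
from statistics import mean


class ToyTargetEncoder:
    """Hypothetical, simplified target encoder: replaces each category with the
    mean target value observed for it in the data passed to fit()."""

    def fit(self, categories, y):
        groups = defaultdict(list)
        for cat, target in zip(categories, y):
            groups[cat].append(target)
        # Statistics are learned from the training data only
        self.means_ = {cat: mean(vals) for cat, vals in groups.items()}
        self.global_mean_ = mean(y)
        return self

    def transform(self, categories):
        # Unseen categories fall back to the training-set global mean
        return [self.means_.get(cat, self.global_mean_) for cat in categories]

    def fit_transform(self, categories, y):
        return self.fit(categories, y).transform(categories)


zip_train = ["10001", "10001", "94105", "94105"]
price_train = [500.0, 700.0, 1000.0, 1200.0]
zip_test = ["94105", "60601"]

enc = ToyTargetEncoder()
encoded_train = enc.fit_transform(zip_train, price_train)
encoded_test = enc.transform(zip_test)  # test prices are never looked at
print(encoded_train)  # [600.0, 600.0, 1100.0, 1100.0]
print(encoded_test)   # [1100.0, 850.0]
```

The unseen zip code "60601" gets the training-set global mean (850.0), because nothing about the test set is learned.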

We use a pipeline with scaling; again, the scaler should be fitted on the training set only and then applied to the test set.

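from sklearn.pipeline import Pipeline
from sklearn.preprocessing import RobustScaler
from sklearn.svm import LinearSVR

pipe = Pipeline(steps=[("preprocessor", preprocessor),  # target encoding
                       ("scaler", RobustScaler()),      # median/IQR scaling
                       ("clf", LinearSVR()),            # final estimator (a regressor)
                       ])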
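The scaler follows the same pattern. Below is a minimal single-feature sketch of the statistic RobustScaler uses by default, (x - median) / interquartile range; `ToyRobustScaler` and the data are hypothetical, and the `"inclusive"` quantile method is assumed here to approximate the percentile interpolation scikit-learn uses. Unlike the real scaler, it does no guarding against a zero IQR.

```python
from statistics import median, quantiles


class ToyRobustScaler:
    """Hypothetical, single-feature stand-in for RobustScaler:
    (x - median) / interquartile range, with both statistics learned in fit()."""

    def fit(self, x):
        q1, _, q3 = quantiles(x, n=4, method="inclusive")
        self.center_ = median(x)
        self.scale_ = q3 - q1  # no guard against a zero IQR in this sketch
        return self

    def transform(self, x):
        # Reuses the training-set median and IQR; nothing is re-estimated here
        return [(v - self.center_) / self.scale_ for v in x]

    def fit_transform(self, x):
        return self.fit(x).transform(x)


x_train = [1.0, 2.0, 3.0, 4.0, 100.0]
x_test = [3.0, 5.0]

scaler = ToyRobustScaler()
train_scaled = scaler.fit_transform(x_train)
test_scaled = scaler.transform(x_test)
print(train_scaled)  # [-1.0, -0.5, 0.0, 0.5, 48.5]
print(test_scaled)   # [0.0, 1.0]
```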

Initialize GridSearchCV with some arbitrary parameters:

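from sklearn.model_selection import GridSearchCV, KFold

# GridSearchCV itself has no random_state parameter; seed the CV splitter instead
kfold = KFold(n_splits=5, shuffle=True, random_state=seed)
tuned_parameters = {"clf__C": [0.1, 1.0, 10.0]}  # hypothetical example grid

gscv = GridSearchCV(estimator=pipe,
                    param_grid=tuned_parameters,
                    cv=kfold,
                    n_jobs=-1)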

Now we can call gscv.fit(X_train, y_train) and gscv.predict(X_test).
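Roughly, gscv.fit() runs a loop like the one below: for every candidate parameter value it fits a fresh copy of the estimator on the training folds, scores it on the held-out fold, and finally refits the best candidate on all the data it was given. This is a simplified plain-Python sketch, not scikit-learn's implementation: `ToyRidge1D`, `kfold_indices`, and `toy_grid_search` are hypothetical, and a one-feature model stands in for `pipe`. With the real `pipe`, the copy fitted on each fold includes the preprocessor and scaler, so they are also fitted on the training folds only.

```python
from statistics import mean


class ToyRidge1D:
    """Hypothetical one-feature ridge regression without intercept:
    w = sum(x * y) / (sum(x * x) + alpha)."""

    def __init__(self, alpha):
        self.alpha = alpha

    def fit(self, x, y):
        self.w_ = sum(a * b for a, b in zip(x, y)) / (sum(a * a for a in x) + self.alpha)
        return self

    def predict(self, x):
        return [self.w_ * a for a in x]


def neg_mse(y_true, y_pred):
    # Higher is better, like scikit-learn's "neg_mean_squared_error" scorer
    return -mean((t - p) ** 2 for t, p in zip(y_true, y_pred))


def kfold_indices(n, k):
    """Contiguous, unshuffled folds (like KFold(shuffle=False))."""
    fold_sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        val = list(range(start, start + size))
        train = [i for i in range(n) if i < start or i >= start + size]
        yield train, val
        start += size


def toy_grid_search(param_grid, x, y, k=3):
    mean_scores = {}
    for alpha in param_grid:
        fold_scores = []
        for train_idx, val_idx in kfold_indices(len(x), k):
            # A fresh, unfitted estimator per fold (scikit-learn uses clone())
            model = ToyRidge1D(alpha).fit([x[i] for i in train_idx], [y[i] for i in train_idx])
            # The held-out fold is only passed to predict(), never to fit()
            preds = model.predict([x[i] for i in val_idx])
            fold_scores.append(neg_mse([y[i] for i in val_idx], preds))
        mean_scores[alpha] = mean(fold_scores)
    best_alpha = max(mean_scores, key=mean_scores.get)
    # Like refit=True: the winner is retrained on all of the data passed to fit()
    best_estimator = ToyRidge1D(best_alpha).fit(x, y)
    return best_alpha, best_estimator, mean_scores


x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
y = [2.0 * v for v in x]
best_alpha, best_estimator, mean_scores = toy_grid_search([0.0, 10.0, 100.0], x, y)
print(best_alpha)                     # 0.0
print(best_estimator.predict([10.0])) # [20.0]
```

Here `best_estimator` plays the role of `gscv.best_estimator_`; `gscv.predict(X_test)` simply delegates to its `predict()`.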

What I do not understand is how this works. For example, when fit() is called, the target encoder and the scaler are fitted to the training set, but the training data never seems to be transformed, so it is never changed. How can GridSearchCV calculate scores based on the untransformed training set?

I do not understand the predict() method at all. How can a prediction be made without ever applying the preprocessor's transformations to the test set X_test? When I apply substantial transformations such as scaling and encoding to the training set, they HAVE to be applied to the test set as well, don't they?

But GridSearchCV internally only calls best_estimator_.predict(), so where does .transform() get called on the test set?


Solution

  • The data transformation is applied implicitly when the pipeline's predict() method is called. The documentation of Pipeline.predict() says so explicitly:

    Apply transforms to the data, and predict with the final estimator

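    So there is no need to transform the data explicitly: each step's transform() is applied automatically before the final estimator makes its prediction. The same holds for fit(): the pipeline calls fit_transform() on every intermediate step and feeds the transformed output to the next step, so the final estimator is trained on encoded and scaled data. There is also no data leakage, because predict() calls only transform() (never fit()) on each step, reusing the statistics learned from the training data. Inside GridSearchCV, each candidate pipeline is fitted on the training folds only and scored on the validation fold, which is likewise only transformed; with the default refit=True, gscv.predict(X_test) then delegates to best_estimator_.predict(), which transforms X_test in the same way.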
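A stripped-down model of the data flow in Pipeline.fit() and Pipeline.predict() is sketched below in plain Python. `ToyPipeline`, `Center`, and `Slope` are hypothetical names; the real Pipeline also handles fit parameters, caching, and more, but the way data moves between steps is the same.

```python
from statistics import mean


class Center:
    """Hypothetical transformer: subtracts the mean learned in fit()."""

    def fit(self, x, y=None):
        self.mean_ = mean(x)
        return self

    def transform(self, x):
        return [v - self.mean_ for v in x]

    def fit_transform(self, x, y=None):
        return self.fit(x, y).transform(x)


class Slope:
    """Hypothetical final estimator: least-squares line through the origin."""

    def fit(self, x, y):
        self.w_ = sum(a * b for a, b in zip(x, y)) / sum(a * a for a in x)
        return self

    def predict(self, x):
        return [self.w_ * a for a in x]


class ToyPipeline:
    """Simplified model of Pipeline.fit()/predict(), not scikit-learn's code."""

    def __init__(self, steps):
        self.steps = steps  # list of (name, estimator); the last one is the final estimator

    def fit(self, x, y):
        for _, step in self.steps[:-1]:
            # Each intermediate step is fitted AND applied; its output feeds the next step
            x = step.fit_transform(x, y)
        self.steps[-1][1].fit(x, y)
        return self

    def predict(self, x):
        for _, step in self.steps[:-1]:
            # Only transform(): statistics learned in fit() are reused, nothing is refitted
            x = step.transform(x)
        return self.steps[-1][1].predict(x)


x_train = [1.0, 2.0, 3.0]
y_train = [-2.0, 0.0, 2.0]
toy_pipe = ToyPipeline([("center", Center()), ("reg", Slope())])
toy_pipe.fit(x_train, y_train)
print(toy_pipe.predict([4.0]))  # [4.0]
```

`predict([4.0])` first centers 4.0 with the training mean (2.0) and only then applies the learned slope, which is the transform-then-predict behavior quoted from the documentation.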