python machine-learning scikit-learn grid-search hyperparameters

GridSearchCV/RandomizedSearchCV with partial_fit in sklearn


According to the documentation of sklearn's RandomizedSearchCV and GridSearchCV, they only call the fit method of the classifier passed to them. They do not support partial_fit, which classifiers can use to train incrementally. I am currently using SGDClassifier, which can be trained on incremental data with partial_fit, and I also want to find the best set of hyperparameters for it. Why don't RandomizedSearchCV and GridSearchCV support partial_fit? I don't see any technical reason why this couldn't be done (please correct me if I am wrong). Any leads would be much appreciated.
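For reference, here is a minimal sketch of the fit-based search the question describes. It assumes synthetic data from make_classification and an arbitrary illustrative grid over alpha and penalty. GridSearchCV clones the SGDClassifier for every candidate and calls fit (never partial_fit) on each training fold.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import GridSearchCV

# Synthetic data, purely for illustration.
X, y = make_classification(n_samples=1000, n_features=20, random_state=0)

# Illustrative grid: 3 values of alpha x 2 penalties = 6 candidates.
param_grid = {"alpha": [1e-4, 1e-3, 1e-2], "penalty": ["l2", "l1"]}

# For each candidate, GridSearchCV clones the estimator and calls .fit()
# on every training fold; the whole fold must fit in memory at once.
search = GridSearchCV(SGDClassifier(random_state=0), param_grid, cv=3)
search.fit(X, y)

print(search.best_params_)
print(round(search.best_score_, 3))
```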


Solution

  • Yes, technically you could write a GridSearchCV for partial_fit as well, but when you think about

    • what are you searching for?
    • what are you optimizing for?

    the problem becomes quite different from what we do with the .fit() approach. Here are my reasons for not having partial_fit in GridSearchCV/RandomizedSearchCV.
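Nothing stops you from running such a search loop yourself. The following is a minimal sketch, not a drop-in GridSearchCV replacement. It assumes the data arrive as consecutive mini-batches, simulated here by slicing a synthetic array with a hypothetical iter_batches helper, and it uses a single held-out validation split instead of cross-validation. ParameterGrid enumerates the candidates, each clone is trained with partial_fit, and only the model after the last batch is scored. That last choice already commits to one answer to "what are you optimizing for?".

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import ParameterGrid, train_test_split

# Synthetic data plus a held-out validation set (an assumption of this sketch).
X, y = make_classification(n_samples=2000, n_features=20, random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.25, random_state=0)
classes = np.unique(y)  # partial_fit must know every label up front


def iter_batches(X, y, batch_size):
    """Hypothetical helper: yield consecutive mini-batches, standing in for a stream."""
    for start in range(0, len(X), batch_size):
        yield X[start:start + batch_size], y[start:start + batch_size]


base = SGDClassifier(random_state=0)
param_grid = {"alpha": [1e-4, 1e-3, 1e-2], "loss": ["hinge", "modified_huber"]}

results = []
for params in ParameterGrid(param_grid):
    model = clone(base).set_params(**params)
    for X_batch, y_batch in iter_batches(X_train, y_train, batch_size=100):
        model.partial_fit(X_batch, y_batch, classes=classes)
    # Score once, after the final batch: "best final model" is the criterion here.
    results.append((model.score(X_val, y_val), params))

best_score, best_params = max(results, key=lambda r: r[0])
print(best_params, round(best_score, 3))
```

Each hyperparameter here is still a single fixed value for the whole stream, which is exactly the limitation discussed next.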

    What are you searching for?

    1. Hyperparameters that are optimal for one batch of data may be sub-optimal for the final model (which is trained on the complete data through multiple partial_fit calls). The problem then becomes finding the best schedule for the hyperparameters, i.e. the optimal value of each hyperparameter at each batch/time step. One example is the decaying learning rate in neural networks: the model is trained through multiple partial_fit calls, and the learning rate is not a single value but a series of values, one for each time step/batch.

    2. You also need to loop through the entire dataset multiple times (multiple epochs) to find the best schedule for the hyperparameters. This would require a fundamental change to the GridSearchCV API.
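To make the "schedule as the hyperparameter" idea concrete, here is a rough sketch. SGDClassifier only offers fixed schedule formulas through its learning_rate parameter ("constant", "optimal", "invscaling", "adaptive"). This sketch works around that by using learning_rate="constant" and overriding eta0 with set_params before every partial_fit call. That is an assumption about how to emulate a custom schedule, not a documented scheduling API. The three candidate schedules and their constants are hypothetical. Each candidate is trained for several epochs, with the same per-epoch shuffles for every candidate, and its validation score is recorded after each epoch.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=2000, n_features=20, random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.25, random_state=0)
classes = np.unique(y)
batch_size, n_epochs = 100, 5
n_batches = int(np.ceil(len(X_train) / batch_size))  # batches per epoch

# Hypothetical candidate schedules: each maps a global step to a learning rate (> 0).
schedules = {
    "constant_0.01": lambda step: 0.01,
    "step_decay": lambda step: 0.1 * 0.5 ** (step // n_batches),  # halve every epoch
    "inverse_time": lambda step: 0.1 / (1.0 + 0.05 * step),
}

curves = {}
for name, schedule in schedules.items():
    rng = np.random.default_rng(0)  # same shuffles for every candidate
    model = SGDClassifier(learning_rate="constant", eta0=schedule(0), random_state=0)
    step, curve = 0, []
    for epoch in range(n_epochs):
        order = rng.permutation(len(X_train))  # reshuffle each epoch
        for start in range(0, len(order), batch_size):
            idx = order[start:start + batch_size]
            # The value used for this batch comes from the schedule being searched.
            model.set_params(eta0=schedule(step))
            model.partial_fit(X_train[idx], y_train[idx], classes=classes)
            step += 1
        curve.append(model.score(X_val, y_val))  # validation score after each epoch
    curves[name] = curve

for name, curve in curves.items():
    print(name, [round(s, 3) for s in curve])
```

Even this toy search needs n_epochs full passes per candidate and produces a whole curve per candidate rather than one number, which is why it does not map onto GridSearchCV's fit-then-score design.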

    What are you optimizing for?

    1. The evaluation metric itself has to change. The goal could be the best performance at the end of all partial_fit calls, reaching a good value of the usual metric (precision, recall, F1-score, etc.) quickly (in fewer batches), or some combination of the two. Hence, this also needs an API change to compute a single value that summarizes the performance of a model trained through multiple partial_fit calls.
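Here is a small sketch of what "a single value summarizing a model trained with multiple partial_fit calls" might look like. It takes a learning curve of per-batch validation scores and reduces it in three different ways: final score, number of batches needed to reach a threshold, and a weighted blend of the two. The function names, the 0.85 threshold, the 0.5 weight, and both curves are made up for illustration. None of this is a scikit-learn API.

```python
def final_score(curve):
    """Performance at the end of all partial_fit calls."""
    return curve[-1]


def batches_to_threshold(curve, threshold):
    """1-based index of the first batch whose score reaches threshold, or None."""
    for i, score in enumerate(curve, start=1):
        if score >= threshold:
            return i
    return None


def combined_score(curve, threshold, weight=0.5):
    """Hypothetical blend: final score plus a bonus for reaching threshold early."""
    hit = batches_to_threshold(curve, threshold)
    speed = 0.0 if hit is None else 1.0 - (hit - 1) / len(curve)
    return (1 - weight) * final_score(curve) + weight * speed


# Made-up per-batch validation scores for two candidates.
curves = {
    "fast_start": [0.80, 0.86, 0.88, 0.88, 0.89],
    "slow_start": [0.60, 0.70, 0.82, 0.88, 0.91],
}

for name, curve in curves.items():
    print(name, final_score(curve), batches_to_threshold(curve, 0.85),
          round(combined_score(curve, 0.85), 3))
```

With these made-up curves, final_score prefers slow_start (0.91 vs 0.89), while combined_score prefers fast_start (0.845 vs 0.655). The "best" hyperparameters therefore depend on a summary choice that GridSearchCV's one-number scoring does not expose.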