python · scikit-learn · pickle · supervised-learning

How to pickle or otherwise save an RFECV model after fitting for rapid classification of novel data


I am generating a predictive model for cancer diagnosis from a moderately large dataset (>4500 features). I have got RFECV to work, providing a model that I can evaluate nicely using ROC curves, confusion matrices, etc., and which performs acceptably when classifying novel data.

Please find a truncated version of my code below.

logo = LeaveOneGroupOut()
model = RFECV(LinearDiscriminantAnalysis(), step=1, cv=logo.split(X, y, groups=trial_number))
model.fit(X, y)

As I say, this works well and provides a model I'm happy with. The trouble is, I would like to be able to save this model so that I don't need to do the lengthy retraining every time I want to evaluate new data.

When I have tried to pickle a standard LDA or other model object, this has worked fine. When I try to pickle this RFECV object, however, I get the following error:

Traceback (most recent call last):
  File "/rds/general/user/***/home/data_analysis/analysis_report_generator.py", line 56, in <module>
    pickle.dump(key, file)
TypeError: cannot pickle 'generator' object

In trying to address this, I have spent a long time trying to RTFM, googling extensively, and digging as deep as I dared into Stack Overflow, without any luck.

I would be grateful if anyone could identify what I could do to pickle this model successfully for future extraction and re-use, or whether there is an equivalent way to save the parameters of the feature-selected LDA model for rapid analysis of new data.


Solution

  • This occurs because LeaveOneGroupOut().split(X, y, groups=groups) returns a generator object, and generator objects cannot be pickled (their execution state is not serializable).

    To pickle the model, you'd have to materialize the generator into a finite list of splits with something like the following, or replace it with a splitter such as StratifiedKFold, which can be passed directly and does not have this issue.

    rfecv = RFECV(
        # ...
        cv=list(LeaveOneGroupOut().split(X, y, groups=groups)),
    )
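
    The underlying limitation is easy to demonstrate in isolation. A minimal sketch (generic generator, nothing sklearn-specific) showing that a generator cannot be pickled while the list materialized from it can:

    ```python
    import pickle

    # A generator object (like the one returned by split()) cannot be pickled.
    gen = (i * i for i in range(3))
    try:
        pickle.dumps(gen)
        picklable = True
    except TypeError:
        picklable = False
    print(picklable)  # False

    # The same values materialized into a list round-trip through pickle fine.
    data = pickle.loads(pickle.dumps([i * i for i in range(3)]))
    print(data)  # [0, 1, 4]
    ```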
    

    MRE putting all the pieces together (here I've assigned groups randomly):

    import pickle
    from sklearn.datasets import make_classification
    from sklearn.feature_selection import RFECV
    from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
    from sklearn.model_selection import LeaveOneGroupOut
    from numpy.random import default_rng
    
    rng = default_rng()
    
    X, y = make_classification(
        n_samples=500, n_features=15, n_informative=3, n_redundant=2,
        n_repeated=0, n_classes=8, n_clusters_per_class=1, class_sep=0.8,
        random_state=0,
    )
    groups = rng.integers(0, 5, size=len(y))
    
    rfecv = RFECV(
        estimator=LinearDiscriminantAnalysis(),
        step=1,
        cv=list(LeaveOneGroupOut().split(X, y, groups=groups)),
        scoring="accuracy",
        min_features_to_select=1,
        n_jobs=4,
    )
    rfecv.fit(X, y)
    
    with open("rfecv_lda.pickle", "wb") as fh:
        pickle.dump(rfecv, fh)
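
    To reuse the saved model without retraining, load it back and call predict on new data. A minimal sketch, where an in-memory round trip via pickle.dumps/pickle.loads stands in for the file on disk, and the synthetic data, sizes, and group assignment are illustrative:

    ```python
    import pickle

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
    from sklearn.feature_selection import RFECV
    from sklearn.model_selection import LeaveOneGroupOut

    X, y = make_classification(n_samples=200, n_features=10, n_informative=3,
                               random_state=0)
    groups = np.arange(len(y)) % 4  # illustrative group labels

    rfecv = RFECV(
        estimator=LinearDiscriminantAnalysis(),
        cv=list(LeaveOneGroupOut().split(X, y, groups=groups)),
    )
    rfecv.fit(X, y)

    # In-memory round trip; pickle.load on a file behaves the same way.
    restored = pickle.loads(pickle.dumps(rfecv))
    X_new = X[:5]  # stand-in for novel data
    preds = restored.predict(X_new)
    ```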
    

    Side note: A better method would be to avoid pickling the RFECV in the first place. rfecv.transform(X) drops the feature columns that the search deemed unnecessary. If you have >4500 features and only need 10, you might want to simplify your data pipeline elsewhere.
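
    One way to follow that advice without re-running the search is to persist only the selected-feature mask (rfecv.support_) together with the final estimator that RFECV already refit on the reduced features (rfecv.estimator_), rather than the whole search object. A minimal sketch, using synthetic data and StratifiedKFold for brevity:

    ```python
    import pickle

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
    from sklearn.feature_selection import RFECV
    from sklearn.model_selection import StratifiedKFold

    X, y = make_classification(n_samples=200, n_features=10, n_informative=3,
                               random_state=0)

    rfecv = RFECV(estimator=LinearDiscriminantAnalysis(), cv=StratifiedKFold(4))
    rfecv.fit(X, y)

    # Persist just the boolean feature mask and the small fitted LDA,
    # not the whole RFECV object with its CV bookkeeping.
    payload = pickle.dumps({"mask": rfecv.support_, "model": rfecv.estimator_})

    # Later: restore and classify novel data after applying the mask.
    saved = pickle.loads(payload)
    preds = saved["model"].predict(X[:, saved["mask"]])
    ```

    Because rfecv.predict internally applies the same mask before calling the final estimator, the restored pair reproduces the RFECV's predictions while being far lighter to store.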