Tags: python, scikit-learn, pickle, dill, scikit-learn-pipeline

Unable to load pickled custom estimator sklearn pipeline


I have a scikit-learn pipeline that uses a custom column transformer, a custom estimator, and several lambda functions.

Because pickle cannot serialize lambda functions, I am using dill.

Here is the custom estimator I have:

import pandas as pd
from sklearn.base import BaseEstimator

class customOLS(BaseEstimator):
    def __init__(self, ols):
        # the argument is named `ols` but stored as `estimator_ols`
        self.estimator_ols = ols

    def fit(self, X, y):
        X = pd.DataFrame(X)
        y = pd.DataFrame(y)
        print('---- Training OLS')
        # replaces the sm.OLS class with the fitted results object
        self.estimator_ols = self.estimator_ols(y,X).fit()
        #print('---- Training LR')
        #self.estimator_lr = self.estimator_lr.fit(X,y)
        return self

    def get_estimators(self):
        return self.estimator_ols #, self.estimator_lr

    def predict_ols(self, X):
        res = self.estimator_ols.predict(X)
        return res

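import statsmodels.api as sm
from sklearn.pipeline import Pipeline

# drop_cols, feature_remover and preprocess_ppl are the custom steps (not shown)
pipeline2 = Pipeline(
    steps=[
        ('dropper', drop_cols),
        ('remover', feature_remover),
        ("preprocessor", preprocess_ppl),
        ("estimator", customOLS(sm.OLS)),
    ]
)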

This is how I serialize it (I have to use open(); otherwise I get an UnsupportedOperation error about read/write):

import dill

with open('data/baseModel_LR.joblib', "wb") as f:
    dill.dump(pipeline2, f)

But when I try to load the pickled object:

with open('data/baseModel_LR.joblib', "rb") as f:
    model = dill.load(f)
model  # in a notebook, this displays the object via its repr

I get this error related to the custom estimator:

AttributeError: 'customOLS' object has no attribute 'ols'



Solution

  • The problem lies in these two lines:

        def __init__(self, ols):
            self.estimator_ols = ols
    

    Here's an excerpt from the sklearn documentation, which explains why this won't work:

    All scikit-learn estimators have get_params and set_params functions. The get_params function takes no arguments and returns a dict of the __init__ parameters of the estimator, together with their values.

    Source: scikit-learn developer documentation, "Developing scikit-learn estimators".

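    So if your constructor has a parameter named ols, scikit-learn assumes that your object also has an attribute named ols. get_params() reads the parameter names from the __init__ signature and then looks up an attribute with each of those names; since customOLS only sets estimator_ols, the lookup of ols raises AttributeError. repr() calls get_params(), so simply displaying the loaded object (the bare model on the last line of the cell) is enough to trigger the error. Loading itself succeeds, and displaying pipeline2 before pickling should fail in the same way.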

    To fix it, change the constructor to this:

        def __init__(self, estimator_ols):
            self.estimator_ols = estimator_ols
    

    When I do that, I am able to save and load the pipeline.
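The attribute lookup described in this answer can be imitated in a few lines of plain Python. The sketch below is a simplified stand-in for BaseEstimator.get_params(deep=False), not scikit-learn's actual code: it reads the parameter names from the __init__ signature with inspect.signature and fetches each one with getattr. The helper get_params_sketch and the classes MismatchedOLS and MatchedOLS are hypothetical names used only for illustration.

```python
import inspect


def get_params_sketch(obj):
    """Simplified imitation of BaseEstimator.get_params(deep=False).

    Collects the names of the __init__ parameters (excluding self) and
    looks each one up as an attribute of the same name on the object.
    """
    init_signature = inspect.signature(type(obj).__init__)
    names = [name for name in init_signature.parameters if name != "self"]
    # getattr raises AttributeError when no attribute matches a parameter name
    return {name: getattr(obj, name) for name in names}


class MismatchedOLS:
    # Same pattern as the question: the parameter is `ols`,
    # but the value is stored under a different attribute name.
    def __init__(self, ols):
        self.estimator_ols = ols


class MatchedOLS:
    # Fixed pattern: parameter name and attribute name agree.
    def __init__(self, estimator_ols):
        self.estimator_ols = estimator_ols


try:
    get_params_sketch(MismatchedOLS("factory"))
except AttributeError as exc:
    print("AttributeError:", exc)

print(get_params_sketch(MatchedOLS("factory")))
```

The real get_params also supports deep=True, which additionally recurses into parameter values that have their own get_params; the sketch leaves that out.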
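A slightly fuller sketch of the fixed estimator follows, under a few stated assumptions. Beyond renaming the constructor argument, it stores the fitted model in a separate attribute with a trailing underscore (results_) instead of overwriting estimator_ols inside fit, which follows the scikit-learn convention for fitted state; otherwise get_params() would report the fitted results object rather than the original factory. The prediction method is named predict so that Pipeline.predict can reach it (the question's predict_ols would not be called by the pipeline). To avoid a statsmodels dependency, LstsqOLS is a hypothetical stand-in that mimics the sm.OLS(endog, exog).fit().predict(X) call pattern with numpy.linalg.lstsq; in the real pipeline you would pass sm.OLS as in the question. CustomOLS is likewise an illustrative name.

```python
import numpy as np
from sklearn.base import BaseEstimator, clone


class LstsqOLS:
    """Hypothetical stand-in for the statsmodels OLS(endog, exog).fit() pattern."""

    def __init__(self, endog, exog):
        self.endog = np.asarray(endog, dtype=float).ravel()
        self.exog = np.asarray(exog, dtype=float)

    def fit(self):
        # Least-squares coefficients; like sm.OLS, no intercept is added
        self.params, *_ = np.linalg.lstsq(self.exog, self.endog, rcond=None)
        return self

    def predict(self, exog):
        return np.asarray(exog, dtype=float) @ self.params


class CustomOLS(BaseEstimator):
    def __init__(self, estimator_ols):
        # Store the constructor argument unchanged, under the same name
        self.estimator_ols = estimator_ols

    def fit(self, X, y):
        # Fitted state goes into a trailing-underscore attribute, so
        # get_params() still returns the unfitted factory
        self.results_ = self.estimator_ols(y, X).fit()
        return self

    def predict(self, X):
        return self.results_.predict(X)


rng = np.random.default_rng(0)
X = rng.normal(size=(50, 2))
y = X @ np.array([1.5, -2.0])

model = CustomOLS(LstsqOLS).fit(X, y)
print(repr(model))                       # no AttributeError
print(model.get_params())
print(np.allclose(model.predict(X), y))  # True: the data is exactly linear

# clone() rebuilds an unfitted copy from get_params()
fresh = clone(model)
print(hasattr(fresh, "results_"))        # False
```

Keeping constructor arguments untouched is what lets clone(), which scikit-learn uses in tools such as cross-validation and grid search, rebuild an unfitted copy from get_params().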