python-3.x scikit-learn imputation

How to implement a sklearn transformer using sklearn.impute.SimpleImputer that returns a pandas DataFrame


I want to implement a custom transformer that wraps an sklearn imputer, e.g., sklearn.impute.SimpleImputer.

The output should be a pandas DataFrame.

I have the following code, but I'm not sure whether it is correct:

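import pandas as pd
from sklearn.base import TransformerMixin
from sklearn.impute import SimpleImputer

class DFSimpleImputer(TransformerMixin):

    def __init__(self, *args, **kwargs):
        # wrap a regular SimpleImputer; all arguments are passed straight through
        self.imp = SimpleImputer(*args, **kwargs)

    def fit(self, X, y=None, **fit_params):
        self.imp.fit(X)
        return self

    def transform(self, X):
        # assumes X is a DataFrame and that the imputer drops no columns
        Ximp = self.imp.transform(X)
        # put the original index and column labels back on the NumPy result
        Xfilled = pd.DataFrame(Ximp, index=X.index, columns=X.columns)
        return Xfilled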

Solution

  • Yes, the code above works for DataFrame input and returns a DataFrame. The question you need to ask is why you need a DataFrame when building transformers (yes, it adds column labels, which make the output easier to read). A NumPy array may be the better choice, since you may encounter sparse matrices, and converting a large sparse matrix into a dense DataFrame can eat up all your RAM.