
scikit-learn: How to compose LabelEncoder and OneHotEncoder with a pipeline?


While preprocessing the labels for a machine-learning classification task, I need to one-hot encode labels that take string values. However, OneHotEncoder from sklearn.preprocessing and to_categorical from keras.utils.np_utils both require integer inputs (OneHotEncoder has accepted strings since scikit-learn 0.20). This means I need to put a LabelEncoder in front of the one-hot encoder. I have done this by hand with a custom class:

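from sklearn.preprocessing import LabelEncoder, OneHotEncoder

class LabelOneHotEncoder():
    def __init__(self):
        self.ohe = OneHotEncoder()
        self.le = LabelEncoder()
    def fit_transform(self, x):
        # Strings -> 1-D integer codes, then a 2-D column -> one-hot matrix
        features = self.le.fit_transform(x)
        return self.ohe.fit_transform(features.reshape(-1, 1))
    def transform(self, x):
        # Reshape the LabelEncoder output, not the raw string input
        return self.ohe.transform(self.le.transform(x).reshape(-1, 1))
    def inverse_transform(self, x):
        # OneHotEncoder.inverse_transform (scikit-learn >= 0.20) returns an
        # (n, 1) column; flatten it before decoding the integer codes
        return self.le.inverse_transform(self.ohe.inverse_transform(x).ravel())
    def inverse_labels(self, x):
        return self.le.inverse_transform(x)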

I am confident there must be a way of doing this within the sklearn API using sklearn.pipeline.Pipeline, but when I use:

LabelOneHotEncoder = Pipeline( [ ("le",LabelEncoder), ("ohe", OneHotEncoder)])

I get the error ValueError: bad input shape () from the OneHotEncoder. My guess is that the output of the LabelEncoder needs to be reshaped by adding a trivial second axis, but I am not sure how to do that inside a pipeline.
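To make the shape mismatch concrete, here is a minimal sketch that chains the two encoders by hand and prints the intermediate shapes. The toy labels ["dog", "cat", "dog"] are only illustrative, and the sketch assumes a reasonably recent scikit-learn.

```python
from sklearn.preprocessing import LabelEncoder, OneHotEncoder

# Illustrative toy labels
labels = ["dog", "cat", "dog"]

# LabelEncoder maps each string to an integer (classes are sorted,
# so "cat" -> 0 and "dog" -> 1) and returns a 1-D array.
le = LabelEncoder()
codes = le.fit_transform(labels)
print(codes.shape)  # (3,)

# OneHotEncoder expects 2-D input of shape (n_samples, n_features),
# so add the trivial second axis with reshape(-1, 1).
column = codes.reshape(-1, 1)
print(column.shape)  # (3, 1)

# One row per label, one column per class (sparse by default)
ohe = OneHotEncoder()
onehot = ohe.fit_transform(column)
print(onehot.toarray())
```

The reshape is the only glue needed between the two steps; the question is how to get a Pipeline to perform it.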


Solution

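  • It's strange that they don't play together nicely; I'm surprised. Part of the mismatch is that LabelEncoder is designed for target labels: its fit_transform takes only y and returns a 1-D array, whereas OneHotEncoder expects a 2-D X. I'd subclass LabelEncoder so that it returns the reshaped data, as you suggested.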

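    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import LabelEncoder, OneHotEncoder

    class ModifiedLabelEncoder(LabelEncoder):
        # Return an (n_samples, 1) column instead of a 1-D array so the next
        # step (OneHotEncoder) receives the 2-D input it expects.
        # *args/**kwargs absorb the extra y argument that Pipeline passes.

        def fit_transform(self, y, *args, **kwargs):
            return super().fit_transform(y).reshape(-1, 1)

        def transform(self, y, *args, **kwargs):
            return super().transform(y).reshape(-1, 1)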
    

    Then using the pipeline should work.

    pipe = Pipeline([("le", ModifiedLabelEncoder()), ("ohe", OneHotEncoder())])
    # Returns a 3x2 sparse one-hot matrix; column 0 is 'cat', column 1 is 'dog'
    pipe.fit_transform(['dog', 'cat', 'dog'])
    

    For reference, the LabelEncoder source: https://github.com/scikit-learn/scikit-learn/blob/a24c8b46/sklearn/preprocessing/label.py#L39
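As an aside, assuming scikit-learn 0.20 or later, OneHotEncoder accepts string categories directly, so the LabelEncoder step (and the subclass) is no longer needed; only the reshape to a 2-D column remains. A minimal sketch with the same illustrative labels, including the inverse mapping back to strings:

```python
import numpy as np
from sklearn.preprocessing import OneHotEncoder

# Illustrative toy labels; OneHotEncoder still wants a 2-D column.
labels = np.array(["dog", "cat", "dog"]).reshape(-1, 1)

# Assumes scikit-learn >= 0.20, where string categories are accepted directly.
ohe = OneHotEncoder()
onehot = ohe.fit_transform(labels)  # sparse matrix, shape (3, 2)

# Categories are sorted, so column 0 is "cat" and column 1 is "dog".
print(ohe.categories_[0])
print(onehot.toarray())

# Encode new labels with the already-fitted categories...
new = ohe.transform(np.array([["cat"]])).toarray()

# ...and map one-hot rows back to the original strings.
# inverse_transform returns shape (n_samples, 1), so flatten it.
recovered = ohe.inverse_transform(onehot).ravel()
print(recovered)
```

By default, transform raises an error on labels not seen during fit; OneHotEncoder(handle_unknown="ignore") encodes them as all-zero rows instead. For target labels specifically, sklearn.preprocessing.LabelBinarizer also maps 1-D string labels straight to a one-hot matrix, though with only two classes it returns a single column.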