python machine-learning scikit-learn one-hot-encoding feature-engineering

OneHotEncoding after LabelEncoding


In scikit-learn, how can I do OneHotEncoding after LabelEncoding?

What I have done so far is identify all the string features of my dataset, like this:

# Categorical boolean mask
categorical_feature_mask = X.dtypes==object
# filter categorical columns using mask and turn it into a list
categorical_cols = X.columns[categorical_feature_mask].tolist()

After that, I label-encoded those columns in place using column indexing, like this:

le = LabelEncoder()
X[categorical_cols] = X[categorical_cols].apply(lambda col: le.fit_transform(col))
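To make concrete what this step does to a feature, here is a minimal sketch on a hypothetical `colors` column (not from the question). LabelEncoder assigns integer codes in sorted order of the category labels, so a model sees blue < green < red even though the colours have no natural order.

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Hypothetical feature column with three unordered categories
colors = pd.Series(["red", "green", "blue", "green"])

le = LabelEncoder()
codes = le.fit_transform(colors)

# Classes are stored in sorted order; each label is replaced by its index in that order
print(le.classes_.tolist())  # ['blue', 'green', 'red']
print(codes.tolist())        # [2, 1, 0, 1]
```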

My results were not very good, so I want to try OneHotEncoding to see if it improves performance.

This is my code:

# Note: categorical_features was deprecated in scikit-learn 0.20 and removed in 0.22
ohe = OneHotEncoder(categorical_features = categorical_cols)
X[categorical_cols] = ohe.fit_transform(df).toarray()

I have tried different approaches, but what I am trying to accomplish is to overwrite the label-encoded features with their one-hot encoded versions.
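One reason the in-place overwrite cannot work as written: one-hot encoding produces one output column per category, so the encoded array usually has more columns than `categorical_cols`, and assigning it back to `X[categorical_cols]` fails on the column-count mismatch. A minimal sketch, assuming a hypothetical `X` with a single three-category column:

```python
import pandas as pd
from sklearn.preprocessing import OneHotEncoder

# Hypothetical frame: one categorical column with 3 distinct values, one numeric column
X = pd.DataFrame({"color": ["red", "green", "blue", "green"],
                  "size": [1.0, 2.5, 3.0, 0.5]})
categorical_cols = ["color"]

ohe = OneHotEncoder()
encoded = ohe.fit_transform(X[categorical_cols]).toarray()

# One input column becomes one output column per category
print(len(categorical_cols), encoded.shape[1])  # 1 3
```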


Solution

  • OneHotEncoder accepts string categorical features directly, so there is no need to apply a LabelEncoder first. Also note that you should not use a LabelEncoder to encode features at all; see the question "LabelEncoder for features?" for a detailed explanation. A LabelEncoder only makes sense for the target variable.

    So select the categorical columns (df.select_dtypes is normally used for this) and fit the encoder on those columns only. Here's a sketch of how you could proceed:

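    import pandas as pd
    from sklearn.preprocessing import OneHotEncoder

    # One-hot encode the categorical (object-dtype) columns
    oh_cols = df.select_dtypes('object').columns
    X_cat = df[oh_cols].to_numpy()
    oh = OneHotEncoder()
    # fit() returns the fitted encoder itself, so one_hot_cols is the same object as oh
    one_hot_cols = oh.fit(X_cat)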
    

    Then call the encoder's transform method. If you want to reconstruct the dataframe (as your code suggests), get_feature_names_out (get_feature_names in scikit-learn versions before 1.0; it was removed in 1.2) gives you column names built from the input column names and their categories:

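    # get_feature_names_out requires scikit-learn >= 1.0 (older versions: get_feature_names)
    df_prepr = pd.DataFrame(one_hot_cols.transform(X_cat).toarray(),
                            columns=one_hot_cols.get_feature_names_out(input_features=oh_cols),
                            index=df.index)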