I'm trying to use ColumnTransformer with OneHotEncoder to transform my categorical data.
A quick look at my data:
I want to one-hot encode 3 features: 'sex', 'smoker' and 'region', so I use scikit-learn's ColumnTransformer. (I don't want to separate the numerical and categorical columns and then transform them separately; I just want to do it all on a single dataset.)
My code:
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder

cat_feature = X.select_dtypes(include='object').columns  # names of the categorical columns only
enc = ColumnTransformer([('one_hot_encoder', OneHotEncoder(), cat_feature)],
                        remainder='passthrough')
X_transformed = enc.fit_transform(X)  # transformed version of the original data
My problem is that X_transformed is a plain array with all the feature names removed, which makes it a little confusing to debug.
So is there any way to retain my columns' names after this transformation? I want to incorporate this transformer into a pipeline, so I can't use pd.get_dummies.
Thank you!!
You will have to write a custom transformer that does the passthrough and supports get_feature_names.
Steps:
1. Create a custom passthrough Transformer which will return the passthrough column names via get_feature_names.
2. Don't use remainder = 'passthrough', but rather use our custom Transformer for the remaining columns.
3. Use enc.get_feature_names() to get the feature list.
Sample:
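import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder

# Note: ColumnTransformer.get_feature_names was deprecated in scikit-learn 1.0
# and removed in 1.2, so this sample needs an older scikit-learn version.

df = pd.DataFrame({
    'age': [1, 2, 3, 4],
    'sex': ['male', 'female'] * 2,
    'bmi': [1.1, 2.2, 3.3, 4.4],
    'children': [1] * 4,
    'smoker': ['yes', 'no'] * 2
})

cat_features = df.select_dtypes(include='object').columns
passthrough_features = [c for c in df.columns if c not in cat_features]

class PassthroughTransformer(BaseEstimator, TransformerMixin):
    """Returns its input unchanged but remembers the column names."""

    def fit(self, X, y=None):
        self.cols = X.columns  # remember the incoming column names
        return self

    def transform(self, X, y=None):
        return X.values  # pass the data through as a NumPy array

    def get_feature_names(self):
        return self.cols

enc = ColumnTransformer([('1hot', OneHotEncoder(), cat_features),
                         ('pass', PassthroughTransformer(), passthrough_features)])
X_transformed = enc.fit_transform(df)
pd.DataFrame(X_transformed, columns=enc.get_feature_names())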
Output:
1hot__x0_female 1hot__x0_male 1hot__x1_no 1hot__x1_yes pass__age pass__bmi pass__children
0 0.0 1.0 0.0 1.0 1.0 1.1 1.0
1 1.0 0.0 1.0 0.0 2.0 2.2 1.0
2 0.0 1.0 0.0 1.0 3.0 3.3 1.0
3 1.0 0.0 1.0 0.0 4.0 4.4 1.0