I am trying to understand the inner workings of a scikit-learn Pipeline.
Consider the below data set and pipeline construction.
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder

data = pd.DataFrame({
    'Name': ['Alice', 'Bob', 'Charlie'],
    'Age': [30, 40, 37],
    'City': ['Amsterdam', 'Berlin', 'Copenhagen']
})
OHE = OneHotEncoder()
ct = ColumnTransformer(transformers=[('One-hot', OHE, ['City'])])
ppln = Pipeline(steps=[('preprocessor', ct),
                       ('estimator', <some_estimator>())
                       ])
The next step is then a fit of sorts:
model = ppln.fit(data)
To the best of my understanding, the above is a series of steps that ends in <some_estimator>.fit(???).
My question is: what will the ??? actually be (or how can I determine what it will be)?
I am unsure about this because I don't know exactly how the 'preprocessor' interacts with the data. On its own, ct.transform(data) returns a matrix-like object. In this case that would be:
[[1,0,0],
[0,1,0],
[0,0,1]]
I am guessing that something will happen that eventually makes it so that:
??? =
[['Alice' , 30, 1, 0, 0],
['Bob' , 40, 0, 1, 0],
['Charlie', 37, 0, 0, 1]]
I would like to do better than guessing: to know for sure what happens, and to gain a more complete view of what goes on under the hood when using a Pipeline.
What you are missing is the remainder parameter of ColumnTransformer.
By default, ColumnTransformer drops every column that is not listed in transformers. From the documentation:
By default, only the specified columns in transformers are transformed and combined in the output, and the non-specified columns are dropped. (default of 'drop').
So when you don't specify remainder, you get only the one-hot encoding of the City column:
# Name, Age, City --> ColumnTransformer --> OHE(City)
ct = ColumnTransformer(transformers = [('One-hot', OHE, ['City'])])
ct.fit_transform(data)
# Output
array([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]])
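To check that this array is also what ends up in the final step's fit call, one option is to plug in a small estimator that does nothing but store its input. The sketch below assumes a hypothetical RecordingEstimator class (not part of scikit-learn) standing in for <some_estimator>; the rest mirrors the setup from the question.

```python
import pandas as pd
from sklearn.base import BaseEstimator
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder


class RecordingEstimator(BaseEstimator):
    """Hypothetical stand-in for <some_estimator>: only stores what fit() receives."""

    def fit(self, X, y=None):
        self.seen_X_ = X  # keep a reference to the data handed over by the Pipeline
        return self


data = pd.DataFrame({
    'Name': ['Alice', 'Bob', 'Charlie'],
    'Age': [30, 40, 37],
    'City': ['Amsterdam', 'Berlin', 'Copenhagen'],
})

ct = ColumnTransformer(transformers=[('One-hot', OneHotEncoder(), ['City'])])
ppln = Pipeline(steps=[('preprocessor', ct),
                       ('estimator', RecordingEstimator())])
ppln.fit(data)

# The final step saw the preprocessor's output, not the original DataFrame.
seen = ppln.named_steps['estimator'].seen_X_
print(seen.shape)  # (3, 3): only the one-hot columns, Name and Age are gone
```

In other words, Pipeline.fit calls fit_transform on each intermediate step in turn and passes the result of the last one to the final estimator's fit, so ??? is whatever the preprocessor returns.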
Using remainder='passthrough' gives you:
# Name, Age, City --> ColumnTransformer --> OHE(City), Name, Age
ct = ColumnTransformer(transformers = [('One-hot', OHE, ['City'])],
remainder='passthrough')
ct.fit_transform(data)
# Output
array([[1.0, 0.0, 0.0, 'Alice', 30],
[0.0, 1.0, 0.0, 'Bob', 40],
[0.0, 0.0, 1.0, 'Charlie', 37]], dtype=object)