python, scikit-learn, data-preprocessing

Scikit-learn pipeline returns list of zeroes


I don't understand why my pipeline returns all zeros instead of the scaled values I expect.

Pipeline code:

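from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

my_pipeline = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='median')),
    ('std_scaler', StandardScaler())
])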

Real data:

real = [[0.02498, 0.0, 1.89, 0.0, 0.518, 6.54, 59.7, 6.2669, 1.0, 422.0, 15.9, 389.96, 8.65]]

The pipeline output that I want:

want = [[-0.44228927, -0.4898311 , -1.37640684, -0.27288841, -0.34321545, 0.36524574, -0.33092752,  1.20235683, -1.0016859 ,  0.05733231, -1.21003475,  0.38110555, -0.57309194]]

But after running the code below:

getting = my_pipeline.fit_transform(real)

I am getting:

[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]

Solution

  • The problem

    This is expected behavior, because real holds a single sample: one row with 13 columns.

    After the first step of the pipeline, the SimpleImputer, the output is a NumPy array of shape (1, 13), i.e. one row and 13 columns.

    si = SimpleImputer()
    si_out = si.fit_transform(real)
    
    si_out.shape
    # (1, 13)
    

    The (1, 13) shape is the problem. StandardScaler standardizes each column separately: it subtracts the column mean and divides by the column standard deviation. It therefore "sees" 13 columns with a single value each. The mean of a single value is the value itself, so every entry becomes 0 once the mean is subtracted.

    sc = StandardScaler()
    sc.fit_transform(si_out)
    

    returns

    array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
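
    To make the mechanism concrete, here is a minimal sketch that fits a StandardScaler on the single row from the question and inspects the statistics it learns. The variable names are only illustrative. With one row, each column's mean is the value itself and each column's variance is 0; scikit-learn replaces a zero scale with 1 rather than dividing by zero, so the output is 0 everywhere.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# The single sample from the question: 1 row, 13 columns
real = [[0.02498, 0.0, 1.89, 0.0, 0.518, 6.54, 59.7, 6.2669, 1.0, 422.0, 15.9, 389.96, 8.65]]

sc = StandardScaler()
out = sc.fit_transform(real)

print(sc.mean_)   # per-column means: identical to the row itself
print(sc.var_)    # per-column variances: all 0 with only one sample
print(sc.scale_)  # per-column divisors: zero scale is replaced by 1
print(out)        # (value - mean) / scale: all zeros
```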
    

  • The solution

    If real is actually a single variable/feature with 13 observations, reshape it into one column before fitting.

    import numpy as np
    
    real = np.array([[0.02498, 0.0, 1.89, 0.0, 0.518, 6.54, 59.7, 6.2669, 1.0, 422.0, 15.9, 389.96, 8.65]]).reshape(-1,1)
    
    my_pipeline = Pipeline(steps=[ 
        ('imputer', SimpleImputer(strategy='median')),
        ('std_scaler', StandardScaler())
    ])
    my_pipeline.fit_transform(real)

    returns

    array([[-0.48677709],
           [-0.4869504 ],
           [-0.47383804],
           [-0.4869504 ],
           [-0.48335664],
           [-0.44157747],
           [-0.07276633],
           [-0.44347217],
           [-0.48001264],
           [ 2.44078289],
           [-0.37664007],
           [ 2.21849716],
           [-0.4269388 ]])
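
    Note that the reshape above standardizes the 13 numbers against each other, so it does not reproduce the want array from the question. If the 13 values are instead 13 features of one new sample, and want came from a pipeline fitted on a larger training set (an assumption; the question does not say), a common pattern is to fit on the training data and then call transform, not fit_transform, on the new sample. In the sketch below, X_train is a hypothetical, randomly generated stand-in for that training set, so the numbers it produces will not match want.

```python
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# The new sample from the question: 1 row, 13 features
real = [[0.02498, 0.0, 1.89, 0.0, 0.518, 6.54, 59.7, 6.2669, 1.0, 422.0, 15.9, 389.96, 8.65]]

# ASSUMPTION: hypothetical training set (100 samples, 13 features),
# generated as random noise around the sample. Replace with real data.
rng = np.random.default_rng(0)
X_train = rng.normal(loc=real[0], scale=1.0, size=(100, 13))

my_pipeline = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='median')),
    ('std_scaler', StandardScaler())
])

# Learn the medians, means and standard deviations from the training data only
my_pipeline.fit(X_train)

# Reuse those learned statistics on the new sample (no refitting)
scaled = my_pipeline.transform(real)
print(scaled.shape)
print(scaled)
```

    Because the means and standard deviations now come from many rows, the single new row is scaled relative to the training distribution instead of relative to itself.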