Tags: python, pandas, scikit-learn, imputation

Why is SimpleImputer returning categorical data?


I'm imputing missing values in a DataFrame, using fillna for the numerical columns and SimpleImputer for the categorical columns. The problem is that after running the following code, all of my features have become categorical (dtype object), including the numerical ones.

X = X.fillna(X.mean())
X_test = X_test.fillna(X_test.mean())

object_imputer = SimpleImputer(strategy="most_frequent")
X_temp = pd.DataFrame(object_imputer.fit_transform(X))
X_temp_test = pd.DataFrame(object_imputer.fit_transform(X_test))
X_temp.columns = X.columns
X_temp_test.columns = X_test.columns
X, X_test = X_temp, X_temp_test

The fillna works fine, but it's the SimpleImputer that is causing me problems.

Can you please tell me what the problem is and how I can fix it? Thanks in advance.


Solution

  • Before I say anything else, note that you are calling fit_transform on X and then again on X_test, which refits the imputer on the test data. You should never do this. Instead, fit the imputer on the training data only and then use that fitted instance to transform both datasets (training and testing data).

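    Having said that, your problem is that you are fitting and transforming all columns, the numeric ones included. SimpleImputer returns a NumPy array, and a single array that holds both strings and numbers can only have dtype object. As a consequence, every column of the DataFrame you build from it ends up with type object.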
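The effect can be reproduced on a tiny frame. This is only a sketch with made-up data (the column names `income` and `gender` are placeholders, not from the question's dataset): the numeric column comes back as object even though it had no missing values.

```python
import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer

# Toy frame (made-up data): one numeric column and one string column
df = pd.DataFrame({
    "income": [1000.0, 2500.0, 1800.0],
    "gender": ["Male", np.nan, "Male"],
})

imputer = SimpleImputer(strategy="most_frequent")
out = imputer.fit_transform(df)  # a NumPy array, not a DataFrame

# Rebuilding the DataFrame does not restore the numeric dtype
wrapped = pd.DataFrame(out, columns=df.columns)
print(out.dtype)
print(wrapped.dtypes)
```

Printing the dtypes shows that `income` is no longer float64, which is the same symptom described in the question.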

    I believe this will solve your problem:

    from sklearn.impute import SimpleImputer

    # Impute NaNs of numeric columns, using the TRAINING means for both sets
    train_means = X.mean(numeric_only=True)
    X = X.fillna(train_means)
    X_test = X_test.fillna(train_means)
    
    # Subset of categorical columns
    cat = ['Loan_ID','Gender','Married','Dependents','Education','Self_Employed',
           'Credit_History','Loan_Status']
    # Fit imputer on training data and ONLY on categorical columns
    object_imputer = SimpleImputer(strategy='most_frequent').fit(X[cat])
    # Transform ONLY categorical columns
    X[cat] = object_imputer.transform(X[cat])
    X_test[cat] = object_imputer.transform(X_test[cat])
    

    As you can see, all columns have the correct data type now.

    X.dtypes
    Loan_ID               object
    Gender                object
    Married               object
    Dependents            object
    Education             object
    Self_Employed         object
    ApplicantIncome        int64
    CoapplicantIncome    float64
    LoanAmount           float64
    Loan_Amount_Term     float64
    Credit_History       float64
    Property_Area         object
    Loan_Status           object
    dtype: object
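
If you would rather not type the column names by hand, the same idea can be sketched by picking the columns by dtype. This is an illustrative variant, not part of the original answer: the two small frames below are made-up stand-ins for X and X_test, and `num_cols`, `cat_cols` and `train_means` are names introduced here.

```python
import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer

# Made-up stand-ins for X and X_test
X = pd.DataFrame({
    "Gender": ["Male", np.nan, "Female", "Male"],
    "LoanAmount": [120.0, np.nan, 100.0, 140.0],
})
X_test = pd.DataFrame({
    "Gender": [np.nan, "Female"],
    "LoanAmount": [np.nan, 90.0],
})

# Split the columns by dtype: numeric ones, and everything else
num_cols = X.select_dtypes(include="number").columns
cat_cols = X.columns.difference(num_cols)

# Learn both sets of fill values from the training data only
train_means = X[num_cols].mean()
cat_imputer = SimpleImputer(strategy="most_frequent").fit(X[cat_cols])

# Apply the same fitted values to the training and the test frame
for frame in (X, X_test):
    frame[num_cols] = frame[num_cols].fillna(train_means)
    frame[cat_cols] = cat_imputer.transform(frame[cat_cols])

print(X.dtypes)
print(X_test)
```

Because the imputer only ever sees the non-numeric columns, the numeric columns keep their dtype, and the test set is filled with values learned from the training set.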