
How to impute missing values for multiple columns using a regressor?


This is an example of a larger dataset I have.

Imagine I have a dataframe in which every column has missing values (NaN) in some rows.

import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestRegressor

df = pd.DataFrame({'a':[0.3, 0.2, 0.5, 0.1, 0.4, 0.5, np.nan, np.nan, np.nan, 0.6, 0.3, 0.5],
                   'b':[4, 3, 5, np.nan, np.nan, np.nan, 5, 6, 5, 8, 7, 4],
                   'c':[20, 25, 35, 30, 10, 18, 16, 22, 26, np.nan, np.nan, np.nan]})

I would like to predict these missing values with a model such as RandomForestRegressor, using the other columns as features. In other words, when a sample has a NaN in one column, I want to use the values in the other two columns as features to predict the missing value.

I can usually do this for a single column, but I would like an automated way to do it for every column.

Thank you.


Solution

  • You can use scikit-learn's IterativeImputer and pass a RandomForestRegressor as its estimator parameter:

    import pandas as pd
    import numpy as np
    from sklearn.ensemble import RandomForestRegressor
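    # Required: IterativeImputer is experimental and must be enabled before it is imported
    from sklearn.experimental import enable_iterative_imputer  # noqa: F401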
    from sklearn.impute import IterativeImputer
    
    df = pd.DataFrame({'a':[0.3, 0.2, 0.5, 0.1, 0.4, 0.5, np.nan, np.nan, np.nan, 0.6, 0.3, 0.5],
                       'b':[4, 3, 5, np.nan, np.nan, np.nan, 5, 6, 5, 8, 7, 4],
                       'c':[20, 25, 35, 30, 10, 18, 16, 22, 26, np.nan, np.nan, np.nan]})
    
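    # Each column with NaNs is predicted in turn from the other columns
    imputer = IterativeImputer(estimator=RandomForestRegressor(), random_state=0)
    imputed = imputer.fit_transform(df)  # returns a NumPy array, so column names are lost
    print(pd.DataFrame(imputed))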
    

    This returns the following dataframe, in which every NaN has been replaced by a predicted value (the exact numbers may vary with the scikit-learn version):

            0     1      2
    0   0.300  4.00  20.00
    1   0.200  3.00  25.00
    2   0.500  5.00  35.00
    3   0.100  3.69  30.00
    4   0.400  5.53  10.00
    5   0.500  5.78  18.00
    6   0.389  5.00  16.00
    7   0.455  6.00  22.00
    8   0.463  5.00  26.00
    9   0.600  8.00  21.02
    10  0.300  7.00  16.92
    11  0.500  4.00  29.98
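
The imputer returns a plain NumPy array, so the result above loses the original column names. The following is a minimal sketch of one way to keep the labels and inspect only the predicted cells. The names `missing_mask` and `df_imputed` are illustrative, and fixing `random_state` on the forest itself is an assumption made here for reproducibility, not part of the original answer.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.experimental import enable_iterative_imputer  # noqa: F401
from sklearn.impute import IterativeImputer

df = pd.DataFrame({'a': [0.3, 0.2, 0.5, 0.1, 0.4, 0.5, np.nan, np.nan, np.nan, 0.6, 0.3, 0.5],
                   'b': [4, 3, 5, np.nan, np.nan, np.nan, 5, 6, 5, 8, 7, 4],
                   'c': [20, 25, 35, 30, 10, 18, 16, 22, 26, np.nan, np.nan, np.nan]})

# Remember which cells were missing before imputation
missing_mask = df.isna()

imputer = IterativeImputer(
    estimator=RandomForestRegressor(random_state=0),  # assumption: seeded for reproducibility
    random_state=0,
)

# Rebuild a DataFrame with the original column names and index
df_imputed = pd.DataFrame(imputer.fit_transform(df),
                          columns=df.columns, index=df.index)

# Show only the cells that were filled in; observed cells appear as NaN here
print(df_imputed.where(missing_mask))
```

Observed values are left untouched; only the cells flagged in `missing_mask` are replaced. With such a small dataset scikit-learn may emit a convergence warning, which does not prevent the imputation from completing.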