scala, apache-spark, dataframe, apache-spark-sql, imputation

Replace missing values with mean - Spark Dataframe


I have a Spark Dataframe with some missing values. I would like to perform a simple imputation by replacing the missing values with the mean for that column. I am very new to Spark, so I have been struggling to implement this logic. This is what I have managed to do so far:

a) To do this for a single column (let's say Col A), this line of code seems to work:

import org.apache.spark.sql.functions.{mean, when}
import spark.implicits._  // for the $ column syntax

df.withColumn("new_Col", when($"ColA".isNull, df.select(mean("ColA"))
  .first()(0).asInstanceOf[Double])
  .otherwise($"ColA"))

b) However, I have not been able to figure out how to do this for all the columns in my dataframe. I tried the map function, but I believe it loops through each row of the dataframe.

c) There is a similar question on SO - here. And while I liked the solution (using aggregated tables and coalesce), I was keen to know whether there is a way to do this by looping through each column (I come from R, so looping over columns with a higher-order function like lapply feels more natural to me).

Thanks!


Solution

  • Spark >= 2.2

    You can use org.apache.spark.ml.feature.Imputer (which supports both the mean and median strategies; a median variant is sketched after the two snippets below).

    Scala:

    import org.apache.spark.ml.feature.Imputer
    
    val imputer = new Imputer()
      .setInputCols(df.columns)
      .setOutputCols(df.columns.map(c => s"${c}_imputed"))
      .setStrategy("mean")
    
    imputer.fit(df).transform(df)
    

    Python:

    from pyspark.ml.feature import Imputer
    
    imputer = Imputer(
        inputCols=df.columns, 
        outputCols=["{}_imputed".format(c) for c in df.columns]
    )
    imputer.fit(df).transform(df)
    
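    If you need the median instead of the mean, only the strategy changes. A minimal Scala sketch (the column names are hypothetical; keep in mind that Imputer expects floating-point input columns in older Spark versions and computes an approximate median):

    import org.apache.spark.ml.feature.Imputer
    
    // Hypothetical numeric columns to impute
    val numericCols = Array("ColA", "ColB")
    
    val medianImputer = new Imputer()
      .setInputCols(numericCols)
      .setOutputCols(numericCols.map(c => s"${c}_imputed"))
      .setStrategy("median")  // approximate median instead of the default mean
    
    medianImputer.fit(df).transform(df)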

  • Spark < 2.2

    Here you are:

    import org.apache.spark.sql.functions.mean
    
    df.na.fill(df.columns.zip(
      df.select(df.columns.map(mean(_)): _*).first.toSeq
    ).toMap)
    

    where

    df.columns.map(mean(_)): Array[Column] 
    

    computes an average for each column,

    df.select(_: _*).first.toSeq: Seq[Any]
    

    collects the aggregated values and converts the row to Seq[Any] (I know it is suboptimal, but this is the API we have to work with),

    df.columns.zip(_).toMap: Map[String,Any] 
    

    creates a Map[String, Any] which maps each column name to its average, and finally:

    df.na.fill(_): DataFrame
    

    fills the missing values using:

    fill: Map[String, Any] => DataFrame 
    

    from DataFrameNaFunctions.
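
    Put together on a small made-up DataFrame (the data and column names below are purely illustrative, and a SparkSession named spark is assumed for the implicits), the whole recipe looks like this:

    import org.apache.spark.sql.functions.mean
    import spark.implicits._
    
    // Toy data with one null in each column
    val df = Seq(
      (Some(1.0), Some(10.0)),
      (None,      Some(30.0)),
      (Some(3.0), None)
    ).toDF("ColA", "ColB")
    
    // Map(ColA -> 2.0, ColB -> 20.0)
    val means = df.columns.zip(
      df.select(df.columns.map(mean(_)): _*).first.toSeq
    ).toMap
    
    df.na.fill(means).show()
    // +----+----+
    // |ColA|ColB|
    // +----+----+
    // | 1.0|10.0|
    // | 2.0|30.0|
    // | 3.0|20.0|
    // +----+----+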

    To ignore NaN entries, you can replace:

    df.select(df.columns.map(mean(_)): _*).first.toSeq
    

    with:

    import org.apache.spark.sql.functions.{col, isnan, when}
    
    
    df.select(df.columns.map(
      c => mean(when(!isnan(col(c)), col(c)))
    ): _*).first.toSeq
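
    One caveat: mean is not defined for non-numeric columns (it may fail or yield null, which na.fill cannot use as a replacement value), so if your DataFrame mixes types it is safer to restrict the whole computation to the numeric columns, for example along these lines:

    import org.apache.spark.sql.functions.mean
    import org.apache.spark.sql.types.NumericType
    
    // Keep only the columns whose declared type is numeric
    val numericCols = df.schema.fields
      .collect { case f if f.dataType.isInstanceOf[NumericType] => f.name }
    
    val numericMeans = numericCols.zip(
      df.select(numericCols.map(mean(_)): _*).first.toSeq
    ).toMap
    
    df.na.fill(numericMeans)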