Tags: python, dataframe, pyspark, fillna

PySpark - Replace Null values with the mean of the corresponding row


I have the following PySpark dataframe:

df
col1 col2 col3
1      2    3
4    None   6
7      8   None

I want to replace the None (or null) values with the mean of the row they are in. The output would look like:

df_result
col1 col2 col3
1      2    3
4      5    6
7      8   7.5

Everything I tried leads to the error 'Column is not iterable' or 'Invalid argument, not a string or column'. Many thanks for your help!
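
For reference, the input dataframe above can be recreated with a small sketch along these lines (assuming an active SparkSession named spark; the integer column types are just a guess):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [(1, 2, 3), (4, None, 6), (7, 8, None)],
    ["col1", "col2", "col3"],
)
df.show()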


Solution

  • Using only Spark built-in functions:

    Use higher-order array functions (aggregate) and count the number of non-null elements in the array (with lambda functions).

    Then calculate the mean by dividing the sum of the elements by that count.

    Finally, replace the nulls with the row mean (case/when statement).

    Example:

    df.show(10,False)
    #+----+----+----+
    #|col1|col2|col3|
    #+----+----+----+
    #|1   |2   |3   |
    #|4   |null|6   |
    #|7   |8   |null|
    #+----+----+----+
    
    #imports (filter as a higher-order array function requires Spark 3.1+)
    from pyspark.sql.functions import array, coalesce, col, expr, isnull, lit, size, when, filter

    #nulls_count: how many null values there are across all the columns
    #arr_vals: all columns cast into an array, with nulls replaced by 0
    #sum_elems: sum of all elements of the array
    #mean_val: mean calculated over the non-null values only
    df1 = df.withColumn("nulls_count",size(filter(array(*[isnull(col(c)) for c in df.columns]), lambda x: x))).\
      withColumn("arr_vals",array(*[coalesce(col(c),lit(0)) for c in df.columns])).\
      withColumn("sum_elems",expr("aggregate(arr_vals,cast(0 as bigint),(acc, x) -> acc + x)")).\
      withColumn("mean_val",expr('round(sum_elems/((size(arr_vals))-nulls_count),1)'))
    
    df1.select([when(col(c).isNull(), col("mean_val")).otherwise(col(c)).alias(c) for c in df.columns]).show(10,False)
    #+----+----+----+
    #|col1|col2|col3|
    #+----+----+----+
    #|1.0 |2.0 |3.0 |
    #|4.0 |5.0 |6.0 |
    #|7.0 |8.0 |7.5 |
    #+----+----+----+
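
    As a follow-up, the same null replacement can be sketched without the intermediate helper columns by computing the row mean in one expression with the Python higher-order functions filter and aggregate (available in pyspark.sql.functions from Spark 3.1+). Treat this as an equivalent variant of the answer above, not part of it; unlike the answer above, it does not round the mean.

    from pyspark.sql.functions import aggregate, array, col, filter, lit, size, when

    #array of the row's non-null values
    non_null = filter(array(*[col(c) for c in df.columns]), lambda x: x.isNotNull())

    #row mean = sum of the non-null values / number of non-null values
    row_mean = aggregate(non_null, lit(0.0), lambda acc, x: acc + x) / size(non_null)

    df.select([when(col(c).isNull(), row_mean).otherwise(col(c)).alias(c)
               for c in df.columns]).show()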