Tags: python, pyspark, group-by, missing-data

Forward and backward fill within each group in PySpark


How do you forward and backward fill missing values for each group in PySpark? For example, suppose we use the column id to group the data and the column order to sort it, with missing values stored as NaN:

df = spark.createDataFrame([
    ('a', 1.0, 1.0),
    ('b', 1.0, 2.0),
    ('a', 2.0, float("nan")),
    ('b', 2.0, float("nan")),
    ('a', 3.0, 3.0),
    ('b', 3.0, 4.0)],
    ["id", "order", "values"])

+---+-----+------+
| id|order|values|
+---+-----+------+
|  a|  1.0|   1.0|
|  b|  1.0|   2.0|
|  a|  2.0|   NaN|
|  b|  2.0|   NaN|
|  a|  3.0|   3.0|
|  b|  3.0|   4.0|
+---+-----+------+

Expected result for forward fill:

+---+-----+------+
| id|order|values|
+---+-----+------+
|  a|  1.0|   1.0|
|  b|  1.0|   2.0|
|  a|  2.0|   1.0|
|  b|  2.0|   2.0|
|  a|  3.0|   3.0|
|  b|  3.0|   4.0|
+---+-----+------+

Expected result for backward fill:

+---+-----+------+
| id|order|values|
+---+-----+------+
|  a|  1.0|   1.0|
|  b|  1.0|   2.0|
|  a|  2.0|   3.0|
|  b|  2.0|   4.0|
|  a|  3.0|   3.0|
|  b|  3.0|   4.0|
+---+-----+------+

Solution

  • Try replacing the NaN values with nulls first, then use coalesce combined with the last and first functions (with ignoreNulls set to true) over windows, as in the example below:

    import pyspark.sql.functions as F

    # Window specs as raw SQL; `order` is backticked because it is a SQL keyword
    ffill_window = "(partition by id order by `order` rows between unbounded preceding and current row)"
    bfill_window = "(partition by id order by `order` rows between current row and unbounded following)"

    (df
     # NaN is a valid double in Spark, not a null, so convert NaN to null first
     .withColumn("values", F.expr("case when isnan(values) then null else values end"))
     # forward fill: last non-null value from the group start through the current row
     .withColumn("values_ffill", F.expr(f"coalesce(values, last(values, true) over {ffill_window})"))
     # backward fill: first non-null value from the current row through the group end
     .withColumn("values_bfill", F.expr(f"coalesce(values, first(values, true) over {bfill_window})"))
    ).show()
    
    # +---+-----+------+------------+------------+
    # | id|order|values|values_ffill|values_bfill|
    # +---+-----+------+------------+------------+
    # |  b|  1.0|   2.0|         2.0|         2.0|
    # |  b|  2.0|  null|         2.0|         4.0|
    # |  b|  3.0|   4.0|         4.0|         4.0|
    # |  a|  1.0|   1.0|         1.0|         1.0|
    # |  a|  2.0|  null|         1.0|         3.0|
    # |  a|  3.0|   3.0|         3.0|         3.0|
    # +---+-----+------+------------+------------+
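
  • The same fills can also be expressed with the DataFrame Window API instead of SQL strings; a minimal sketch of the equivalent logic:

    from pyspark.sql import Window
    import pyspark.sql.functions as F

    # forward fill looks back from the current row; backward fill looks ahead
    w_ffill = (Window.partitionBy("id").orderBy("order")
               .rowsBetween(Window.unboundedPreceding, Window.currentRow))
    w_bfill = (Window.partitionBy("id").orderBy("order")
               .rowsBetween(Window.currentRow, Window.unboundedFollowing))

    (df
     # replace NaN with null so that ignorenulls can skip the missing entries
     .withColumn("values", F.when(F.isnan("values"), F.lit(None)).otherwise(F.col("values")))
     .withColumn("values_ffill", F.last("values", ignorenulls=True).over(w_ffill))
     .withColumn("values_bfill", F.first("values", ignorenulls=True).over(w_bfill))
    ).show()

    Because both windows include the current row, last and first with ignorenulls=True already return the row's own value when it is not null, so the coalesce from the SQL version is not needed here.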