How can I do forward and backward fill for each group in PySpark? For example, suppose we use the column id to group the data and the column order to sort the values within each group, with missing entries in values:
df = spark.createDataFrame([
    ('a', 1.0, 1.0),
    ('b', 1.0, 2.0),
    ('a', 2.0, float("nan")),
    ('b', 2.0, float("nan")),
    ('a', 3.0, 3.0),
    ('b', 3.0, 4.0)],
    ["id", "order", "values"])
+---+-----+------+
| id|order|values|
+---+-----+------+
| a| 1.0| 1.0|
| b| 1.0| 2.0|
| a| 2.0| NaN|
| b| 2.0| NaN|
| a| 3.0| 3.0|
| b| 3.0| 4.0|
+---+-----+------+
Expected result for forward fill:
+---+-----+------+
| id|order|values|
+---+-----+------+
| a| 1.0| 1.0|
| b| 1.0| 2.0|
| a| 2.0| 1.0|
| b| 2.0| 2.0|
| a| 3.0| 3.0|
| b| 3.0| 4.0|
+---+-----+------+
Expected result for backward fill:
+---+-----+------+
| id|order|values|
+---+-----+------+
| a| 1.0| 1.0|
| b| 1.0| 2.0|
| a| 2.0| 3.0|
| b| 2.0| 4.0|
| a| 3.0| 3.0|
| b| 3.0| 4.0|
+---+-----+------+
Try replacing the NaN values with nulls first, then use coalesce combined with the last and first functions (with ignoreNulls set to true) over windows, as in the example below:
import pyspark.sql.functions as F

# Backtick `order` so the column name is not parsed as the SQL ORDER keyword
ffill_window = "(partition by id order by `order` rows between unbounded preceding and current row)"
bfill_window = "(partition by id order by `order` rows between current row and unbounded following)"

(df
 # last/first with ignoreNulls skip nulls but not NaN, so convert NaN to null first
 .withColumn("values", F.expr("case when isnan(values) then null else values end"))
 .withColumn("values_ffill", F.expr(f"coalesce(values, last(values, true) over {ffill_window})"))
 .withColumn("values_bfill", F.expr(f"coalesce(values, first(values, true) over {bfill_window})"))
).show()
# +---+-----+------+------------+------------+
# | id|order|values|values_ffill|values_bfill|
# +---+-----+------+------------+------------+
# | b| 1.0| 2.0| 2.0| 2.0|
# | b| 2.0| null| 2.0| 4.0|
# | b| 3.0| 4.0| 4.0| 4.0|
# | a| 1.0| 1.0| 1.0| 1.0|
# | a| 2.0| null| 1.0| 3.0|
# | a| 3.0| 3.0| 3.0| 3.0|
# +---+-----+------+------------+------------+
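If you prefer to stay in the DataFrame API, the same fills can be written with Window objects instead of SQL window strings. A minimal sketch, using the same column names (id, order, values):

import pyspark.sql.functions as F
from pyspark.sql import Window

# Convert NaN to null so that last/first with ignorenulls can skip the gaps
df_nulls = df.withColumn(
    "values", F.when(F.isnan("values"), F.lit(None)).otherwise(F.col("values"))
)

# Frame from the start of the group up to the current row (forward fill)
w_ffill = (Window.partitionBy("id").orderBy("order")
           .rowsBetween(Window.unboundedPreceding, Window.currentRow))
# Frame from the current row to the end of the group (backward fill)
w_bfill = (Window.partitionBy("id").orderBy("order")
           .rowsBetween(Window.currentRow, Window.unboundedFollowing))

(df_nulls
 .withColumn("values_ffill", F.last("values", ignorenulls=True).over(w_ffill))
 .withColumn("values_bfill", F.first("values", ignorenulls=True).over(w_bfill))
).show()

Note that the coalesce in the SQL version is not strictly necessary: because the forward-fill frame ends at the current row (and the backward-fill frame starts at it), last/first with ignoreNulls already return the current value whenever it is non-null.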