python · apache-spark · pyspark · apache-spark-sql

Fill between known values and stop


How can I fill just between known values?

Consider the following example:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

localdf = spark.createDataFrame(
    sc.parallelize(
        [
            [1, 24, None, None],
            [1, 23, None, None],
            [1, 22, 1, 1],
            [1, 21, None, 1],
            [1, 20, None, 1],
            [1, 19, 1, 1],
            [1, 18, None, None],
            [1, 17, None, None],
            [1, 16, 2, 2],
            [1, 15, None, None],
            [1, 14, None, None],
            [1, 13, 3, 3],
        ]
    ),
    ["ID", "Record", "Target", "ExpectedValue"],
)


# forward fill: with an orderBy, the default window frame runs from the
# start of the partition up to the current row, so last(..., ignorenulls=True)
# carries the most recent non-null value forward
w = Window.partitionBy("ID").orderBy("Record")

# wrong attempt
localdf = localdf.withColumn(
    "TargetTry", F.last("Target", ignorenulls=True).over(w)
).orderBy("ID", F.desc("Record"))


localdf.show()
+---+------+------+-------------+---------+
| ID|Record|Target|ExpectedValue|TargetTry|
+---+------+------+-------------+---------+
|  1|    24|  NULL|         NULL|        1|
|  1|    23|  NULL|         NULL|        1|
|  1|    22|     1|            1|        1|
|  1|    21|  NULL|            1|        1|
|  1|    20|  NULL|            1|        1|
|  1|    19|     1|            1|        1|
|  1|    18|  NULL|         NULL|        2|
|  1|    17|  NULL|         NULL|        2|
|  1|    16|     2|            2|        2|
|  1|    15|  NULL|         NULL|        3|
|  1|    14|  NULL|         NULL|        3|
|  1|    13|     3|            3|        3|
+---+------+------+-------------+---------+

As shown, the plain forward fill keeps propagating past the known values: records 24, 23, 18, 17, 15, and 14 get filled even though ExpectedValue stays NULL there.

Solution

  • You're halfway there: you need another last over the same column, but in the opposite (descending) order. The result is then the intersection of these two fills, keeping a value only where the forward and backward fills agree:

    from pyspark.sql.functions import asc, col, desc, last, when

    w = Window.partitionBy("ID").orderBy(desc("Record"))    # carries values toward lower Records
    w2 = Window.partitionBy("ID").orderBy(asc("Record"))    # carries values toward higher Records
    localdf = localdf.withColumn("last_desc", last("Target", ignorenulls=True).over(w)) \
        .withColumn("last_asc", last("Target", ignorenulls=True).over(w2)) \
        .withColumn("result", when(col("last_desc") == col("last_asc"), col("last_asc")).otherwise(None))
    localdf.orderBy("ID", desc("Record")).show()
    
    
    +---+------+------+-------------+---------+--------+------+
    | ID|Record|Target|ExpectedValue|last_desc|last_asc|result|
    +---+------+------+-------------+---------+--------+------+
    |  1|    24|  null|         null|     null|       1|  null|
    |  1|    23|  null|         null|     null|       1|  null|
    |  1|    22|     1|            1|        1|       1|     1|
    |  1|    21|  null|            1|        1|       1|     1|
    |  1|    20|  null|            1|        1|       1|     1|
    |  1|    19|     1|            1|        1|       1|     1|
    |  1|    18|  null|         null|        1|       2|  null|
    |  1|    17|  null|         null|        1|       2|  null|
    |  1|    16|     2|            2|        2|       2|     2|
    |  1|    15|  null|         null|        2|       3|  null|
    |  1|    14|  null|         null|        2|       3|  null|
    |  1|    13|     3|            3|        3|       3|     3|
    +---+------+------+-------------+---------+--------+------+
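
    If you only want the final column, the same two-direction fill can be written as a single expression. A minimal sketch under the same setup; w_asc, w_desc, ffill, bfill, and the Filled column name are just illustrative:

    from pyspark.sql import functions as F
    from pyspark.sql.window import Window

    w_asc = Window.partitionBy("ID").orderBy(F.asc("Record"))
    w_desc = Window.partitionBy("ID").orderBy(F.desc("Record"))

    # each fill runs from the start of its ordering up to the current row
    ffill = F.last("Target", ignorenulls=True).over(w_asc)    # carries values toward higher Records
    bfill = F.last("Target", ignorenulls=True).over(w_desc)   # carries values toward lower Records

    # when() without otherwise() yields NULL wherever the two fills disagree,
    # i.e. outside the gap between two known values
    result = localdf.withColumn("Filled", F.when(ffill == bfill, ffill))
    result.orderBy("ID", F.desc("Record")).show()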