python apache-spark pyspark data-processing

Pyspark: Check if column values are monotonically increasing


Problem: Given the PySpark DataFrame below, is it possible to check row-wise whether "some_value" increased compared to the previous row, using a window function (see the example below)?

A solution without lag would be preferred, because I will have multiple columns like "some_value" and I don't know in advance how many there are or what they are called.

Example: here I want to produce a column like "FLAG_INCREASE":

+---+----------+---+----------+-------------+
| id|     datum|lfd|some_value|FLAG_INCREASE|
+---+----------+---+----------+-------------+
|  1|2015-01-01|  4|      20.0|            0|
|  1|2015-01-06|  3|      10.0|            0|
|  1|2015-01-07|  2|      25.0|            1|
|  1|2015-01-12|  1|      30.0|            1|
|  2|2015-01-01|  4|       5.0|            0|
|  2|2015-01-06|  3|      30.0|            1|
|  2|2015-01-12|  1|      20.0|            0|
+---+----------+---+----------+-------------+
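To pin down the rule the expected FLAG_INCREASE column encodes, here is a plain-Python sketch (no Spark). It assumes, reading from the example, that each value is compared with the previous row of the same id in date order, and that the first row of each id gets 0 because it has no predecessor. The helper name flag_increase is hypothetical.

```python
from itertools import groupby
from operator import itemgetter

# Rows from the example: (id, datum, some_value); ISO date strings sort chronologically.
rows = [
    (1, "2015-01-01", 20.0),
    (1, "2015-01-06", 10.0),
    (1, "2015-01-07", 25.0),
    (1, "2015-01-12", 30.0),
    (2, "2015-01-01", 5.0),
    (2, "2015-01-06", 30.0),
    (2, "2015-01-12", 20.0),
]


def flag_increase(rows):
    """Return 1 where a value exceeds the previous value of the same id, else 0."""
    flags = []
    ordered = sorted(rows, key=itemgetter(0, 1))  # sort by id, then datum
    for _, group in groupby(ordered, key=itemgetter(0)):
        previous = None  # reset at the start of each id
        for _, _, value in group:
            # The first row of an id has no predecessor, so it is flagged 0.
            flags.append(1 if previous is not None and value > previous else 0)
            previous = value
    return flags


print(flag_increase(rows))  # [0, 0, 1, 1, 0, 1, 0]
```

The printed list matches the FLAG_INCREASE column in the table above, row for row.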

Code:

import pyspark.sql.functions as F
from pyspark.sql import Row, SparkSession
from pyspark.sql.window import Window

spark = SparkSession.builder.getOrCreate()

# "some_value2" stands in for the other, not-known-in-advance value columns.
row = Row("id", "datum", "lfd", "some_value", "some_value2")
df = spark.sparkContext.parallelize([
    row(1, "2015-01-01", 4, 20.0, 20.0),
    row(1, "2015-01-06", 3, 10.0, 20.0),
    row(1, "2015-01-07", 2, 25.0, 20.0),
    row(1, "2015-01-12", 1, 30.0, 20.0),
    row(2, "2015-01-01", 4, 5.0, 20.0),
    row(2, "2015-01-06", 3, 30.0, 20.0),
    row(2, "2015-01-12", 1, 20.0, 20.0)
]).toDF().withColumn("datum", F.col("datum").cast("date"))
df.show()

+---+----------+---+----------+-----------+
| id|     datum|lfd|some_value|some_value2|
+---+----------+---+----------+-----------+
|  1|2015-01-01|  4|      20.0|       20.0|
|  1|2015-01-06|  3|      10.0|       20.0|
|  1|2015-01-07|  2|      25.0|       20.0|
|  1|2015-01-12|  1|      30.0|       20.0|
|  2|2015-01-01|  4|       5.0|       20.0|
|  2|2015-01-06|  3|      30.0|       20.0|
|  2|2015-01-12|  1|      20.0|       20.0|
+---+----------+---+----------+-----------+

Solution

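  • A lag over a window partitioned by "id" and ordered by "datum" is all you need. On the first row of each id, lag returns null, so the comparison is null and otherwise(0) sets the flag to 0: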

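    from pyspark.sql import functions as F, Window

    # Previous row within the same id, in date order.
    w = Window.partitionBy("id").orderBy("datum")

    df = df.withColumn(
        "FLAG_INCREASE",
        # Null comparison on each id's first row falls through to otherwise(0).
        F.when(F.col("some_value") > F.lag("some_value").over(w), 1).otherwise(0),
    )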