Problem: Given the PySpark DataFrame below, is it possible to check row-wise whether "some_value" increased compared to the previous row, using a window function (see the example below)?
A solution without lag would be preferred, as I will have multiple columns like "some_value" and I don't know in advance how many there are or what they are called.
Example: here I want to achieve a column like "FLAG_INCREASE".
+---+----------+---+----------+-------------+
| id|     datum|lfd|some_value|FLAG_INCREASE|
+---+----------+---+----------+-------------+
|  1|2015-01-01|  4|      20.0|            0|
|  1|2015-01-06|  3|      10.0|            0|
|  1|2015-01-07|  2|      25.0|            1|
|  1|2015-01-12|  1|      30.0|            1|
|  2|2015-01-01|  4|       5.0|            0|
|  2|2015-01-06|  3|      30.0|            1|
|  2|2015-01-12|  1|      20.0|            0|
+---+----------+---+----------+-------------+
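The rule encoded by the expected table can be pinned down with a small plain-Python sketch (no Spark involved; it only makes the semantics explicit): within each id, in datum order, the flag is 1 when some_value is strictly greater than the previous row's value, and the first row of each id has no predecessor, so it gets 0. The function name flag_increase is hypothetical.

```python
from itertools import groupby
from operator import itemgetter

# (id, datum, some_value) rows taken from the example table above
rows = [
    (1, "2015-01-01", 20.0),
    (1, "2015-01-06", 10.0),
    (1, "2015-01-07", 25.0),
    (1, "2015-01-12", 30.0),
    (2, "2015-01-01", 5.0),
    (2, "2015-01-06", 30.0),
    (2, "2015-01-12", 20.0),
]


def flag_increase(rows):
    """Hypothetical helper: 1 where some_value exceeds the previous value of the same id, else 0."""
    flags = []
    # Sort by (id, datum) so groupby sees each id as one contiguous, date-ordered run;
    # ISO date strings sort correctly as plain strings.
    for _, group in groupby(sorted(rows, key=itemgetter(0, 1)), key=itemgetter(0)):
        previous = None  # no previous row at the start of each id
        for _, _, value in group:
            flags.append(1 if previous is not None and value > previous else 0)
            previous = value
    return flags


print(flag_increase(rows))  # [0, 0, 1, 1, 0, 1, 0]
```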
Code:
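import pyspark.sql.functions as F
from pyspark.sql import Row, SparkSession
from pyspark.sql.window import Window

# Reuse the active session if there is one (e.g. in a notebook or pyspark shell)
spark = SparkSession.builder.getOrCreate()

# Two value columns here; in practice there may be more, with names not known in advance
row = Row("id", "datum", "lfd", "some_value", "some_value2")
df = spark.sparkContext.parallelize([
    row(1, "2015-01-01", 4, 20.0, 20.0),
    row(1, "2015-01-06", 3, 10.0, 20.0),
    row(1, "2015-01-07", 2, 25.0, 20.0),
    row(1, "2015-01-12", 1, 30.0, 20.0),
    row(2, "2015-01-01", 4, 5.0, 20.0),
    row(2, "2015-01-06", 3, 30.0, 20.0),
    row(2, "2015-01-12", 1, 20.0, 20.0),
]).toDF().withColumn("datum", F.col("datum").cast("date"))  # string -> DateType
df.show()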
+---+----------+---+----------+-----------+
| id|     datum|lfd|some_value|some_value2|
+---+----------+---+----------+-----------+
|  1|2015-01-01|  4|      20.0|       20.0|
|  1|2015-01-06|  3|      10.0|       20.0|
|  1|2015-01-07|  2|      25.0|       20.0|
|  1|2015-01-12|  1|      30.0|       20.0|
|  2|2015-01-01|  4|       5.0|       20.0|
|  2|2015-01-06|  3|      30.0|       20.0|
|  2|2015-01-12|  1|      20.0|       20.0|
+---+----------+---+----------+-----------+
Answer: you simply need lag, applied over a window partitioned by id and ordered by datum:
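from pyspark.sql import functions as F, Window

# Rows of the same id, in date order
w = Window.partitionBy("id").orderBy("datum")

# lag() returns the previous row's value (null on the first row of each id);
# a comparison with null is null, so otherwise() sets those rows to 0
df = df.withColumn(
    "FLAG_INCREASE",
    F.when(F.col("some_value") > F.lag("some_value").over(w), 1).otherwise(0),
)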