I need some insight into how to do this in Spark.
My dataframe looks like this:
|ID | DATE | State
|X | 20-01-2023 | N
|X | 21-01-2023 | S
|X | 22-01-2023 | S
|X | 23-01-2023 | N
|X | 24-01-2023 | E
|X | 25-01-2023 | E
|Y | 20-01-2023 | S
|Y | 23-01-2023 | S
The state is one of: N (neutral), S (start), FS (false start), E (end), or FE (false end).
What I need is, for each ID (X, Y, ...), to order the rows by date and change each state based on the state in the previous row: a start becomes a false start if it was preceded by a start, and an end becomes a false end if it was preceded by an end. Neutral doesn't change anything.
The output should be something like this:
|ID | DATE |State
|X |20-01-2023 | N
|X |21-01-2023 | S
|X |22-01-2023 | FS
|X |23-01-2023 | N
|X |24-01-2023 | E
|X |25-01-2023 | FE
|Y |20-01-2023 | S
|Y |23-01-2023 | FS
Any help is appreciated!
You can use the lag function to fetch the previous state, compare it with the current state, and apply the appropriate change with when conditions. Here's how you can do it:
from pyspark.sql.window import Window
from pyspark.sql.functions import lag, when, col
data = [
    ("X", "20-01-2023", "N"),
    ("X", "21-01-2023", "S"),
    ("X", "22-01-2023", "S"),
    ("X", "23-01-2023", "N"),
    ("X", "24-01-2023", "E"),
    ("X", "25-01-2023", "E"),
    ("Y", "20-01-2023", "S"),
    ("Y", "23-01-2023", "S"),
]
# Create DataFrame
df = spark.createDataFrame(data, ["id", "date", "state"])
df.show()
+---+----------+-----+
| id|      date|state|
+---+----------+-----+
|  X|20-01-2023|    N|
|  X|21-01-2023|    S|
|  X|22-01-2023|    S|
|  X|23-01-2023|    N|
|  X|24-01-2023|    E|
|  X|25-01-2023|    E|
|  Y|20-01-2023|    S|
|  Y|23-01-2023|    S|
+---+----------+-----+
# Compare each row's state with the previous row's state for the same id:
# S after S: change the current state to FS
# E after E: change the current state to FE
# Anything else (including N): leave the state unchanged.
from pyspark.sql.functions import to_date

# Parse the dd-MM-yyyy strings so rows are ordered chronologically,
# not lexicographically (string order breaks across months).
w = Window.partitionBy("id").orderBy(to_date("date", "dd-MM-yyyy"))

df = (
    df.withColumn("previous_state", lag("state").over(w))
    .withColumn(
        "state",
        when((col("state") == "S") & (col("previous_state") == "S"), "FS")
        .when((col("state") == "E") & (col("previous_state") == "E"), "FE")
        .otherwise(col("state")),
    )
    .drop("previous_state")
)
df.show()
+---+----------+-----+
| id|      date|state|
+---+----------+-----+
|  X|20-01-2023|    N|
|  X|21-01-2023|    S|
|  X|22-01-2023|   FS|
|  X|23-01-2023|    N|
|  X|24-01-2023|    E|
|  X|25-01-2023|   FE|
|  Y|20-01-2023|    S|
|  Y|23-01-2023|   FS|
+---+----------+-----+
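If you want to sanity-check the rule without a Spark session, here is a minimal plain-Python sketch of the same logic. It is only a reference for small samples, not a replacement for the window function. The function name relabel_states and the skip_neutral flag are hypothetical, not from the question: with skip_neutral=True, N rows are ignored when looking for the previous state, which is one possible reading of "neutral doesn't change anything". The sketch parses the dates before sorting, because dd-MM-yyyy strings do not sort chronologically across months.

```python
from collections import defaultdict
from datetime import datetime


def relabel_states(rows, skip_neutral=False):
    """Relabel repeated S/E states as FS/FE within each id.

    rows: iterable of (id, "dd-MM-yyyy" date string, state) tuples.
    skip_neutral: hypothetical option; if True, N rows are skipped when
    looking up the previous state (an assumed interpretation).
    """
    by_id = defaultdict(list)
    for id_, date, state in rows:
        # Keep the parsed date first so sorting is chronological.
        by_id[id_].append((datetime.strptime(date, "%d-%m-%Y"), date, state))

    result = []
    for id_ in sorted(by_id):
        previous = None  # original state of the row we compare against
        for _, date, state in sorted(by_id[id_]):
            new_state = state
            if state == "S" and previous == "S":
                new_state = "FS"
            elif state == "E" and previous == "E":
                new_state = "FE"
            result.append((id_, date, new_state))
            # Track the ORIGINAL state, as lag("state") does in Spark.
            if not (skip_neutral and state == "N"):
                previous = state
    return result
```

Calling relabel_states(data) on the sample list above should give the same states as the Spark output, and comparing the two on a few hand-written cases is a quick way to check the window ordering.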