dataframe, apache-spark, pyspark, apache-spark-sql

How to change a row's value based on the value in the previous row, in a DataFrame ordered by date for each unique ID?


I need some pointers on how to do this in Spark:

My DataFrame looks like this:

|ID |    DATE      |   State
|X  | 20-01-2023   |     N
|X  | 21-01-2023   |     S
|X  | 22-01-2023   |     S
|X  | 23-01-2023   |     N
|X  | 24-01-2023   |     E
|X  | 25-01-2023   |     E
|Y  | 20-01-2023   |     S
|Y  | 23-01-2023   |     S

The state is one of: N (neutral), S (start), FS (false start), E (end), or FE (false end).

What I need is, for each ID (X, Y, ...), to order the rows by date and change each state based on the state in the previous row: a start becomes a false start if it is preceded by a start, and an end becomes a false end if it is preceded by an end. A neutral state doesn't change anything.
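To pin down the rule before writing any Spark code, here is a minimal pure-Python sketch (no Spark involved). It assumes each row is compared with the original, not already relabelled, state of the previous row for the same ID, and that dates are dd-MM-yyyy strings. The function name relabel_states is hypothetical.

```python
from datetime import datetime
from itertools import groupby


def relabel_states(rows):
    """Hypothetical helper: relabel (id, date, state) rows.

    An S preceded by S becomes FS, an E preceded by E becomes FE.
    Assumes 'dd-MM-yyyy' date strings and compares each row with the
    ORIGINAL state of the previous row of the same id (like lag() does).
    """
    # Sort by id, then by the parsed date so the order is chronological.
    ordered = sorted(rows, key=lambda r: (r[0], datetime.strptime(r[1], "%d-%m-%Y")))
    result = []
    for _, group in groupby(ordered, key=lambda r: r[0]):
        previous = None  # the first row of each id has no previous state
        for row_id, date, state in group:
            if state == "S" and previous == "S":
                new_state = "FS"
            elif state == "E" and previous == "E":
                new_state = "FE"
            else:
                new_state = state  # N, or a state not repeating its predecessor
            result.append((row_id, date, new_state))
            previous = state  # remember the original state, not new_state
    return result


data = [
    ("X", "20-01-2023", "N"),
    ("X", "21-01-2023", "S"),
    ("X", "22-01-2023", "S"),
    ("X", "23-01-2023", "N"),
    ("X", "24-01-2023", "E"),
    ("X", "25-01-2023", "E"),
    ("Y", "20-01-2023", "S"),
    ("Y", "23-01-2023", "S"),
]
print([state for _, _, state in relabel_states(data)])
# ['N', 'S', 'FS', 'N', 'E', 'FE', 'S', 'FS']
```

Because the sketch compares against the original previous state, three consecutive starts come out as S, FS, FS rather than alternating.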

The output should look like this:

|ID  |   DATE         |State
|X   |20-01-2023      |  N
|X   |21-01-2023      |  S
|X   |22-01-2023      |  FS
|X   |23-01-2023      |  N
|X   |24-01-2023      |  E
|X   |25-01-2023      |  FE
|Y   |20-01-2023      |  S
|Y   |23-01-2023      |  FS

Any help is appreciated!


Solution

  • You can use the lag window function to fetch the previous row's state within each id (ordered by date), compare it with the current state, and rewrite the state with when conditions. Here's how you can do it:

    from pyspark.sql import SparkSession
    from pyspark.sql.window import Window
    from pyspark.sql.functions import lag, when, col, to_date

    spark = SparkSession.builder.getOrCreate()

    data = [
        ("X", "20-01-2023", "N"),
        ("X", "21-01-2023", "S"),
        ("X", "22-01-2023", "S"),
        ("X", "23-01-2023", "N"),
        ("X", "24-01-2023", "E"),
        ("X", "25-01-2023", "E"),
        ("Y", "20-01-2023", "S"),
        ("Y", "23-01-2023", "S"),
    ]
    # Create DataFrame
    df = spark.createDataFrame(data, ["id", "date", "state"])
    df.show()

    # +---+----------+-----+
    # | id|      date|state|
    # +---+----------+-----+
    # |  X|20-01-2023|    N|
    # |  X|21-01-2023|    S|
    # |  X|22-01-2023|    S|
    # |  X|23-01-2023|    N|
    # |  X|24-01-2023|    E|
    # |  X|25-01-2023|    E|
    # |  Y|20-01-2023|    S|
    # |  Y|23-01-2023|    S|
    # +---+----------+-----+

    # Order rows chronologically within each id. The dates are dd-MM-yyyy
    # strings, so parse them with to_date: sorting the raw strings would
    # compare the day before the month and the year.
    w = Window.partitionBy("id").orderBy(to_date(col("date"), "dd-MM-yyyy"))

    # Fetch the previous state; if it matches the current state:
    # S: change the current state to FS
    # E: change the current state to FE
    # anything else (e.g. N, or the first row of an id): keep the state as is
    df = (
        df.withColumn("previous_state", lag("state").over(w))
        .withColumn(
            "state",
            when((col("state") == "S") & (col("previous_state") == "S"), "FS")
            .when((col("state") == "E") & (col("previous_state") == "E"), "FE")
            .otherwise(col("state")),
        )
        .drop("previous_state")
    )

    # Sort explicitly before showing: row order after a window is not guaranteed.
    df.orderBy("id", to_date(col("date"), "dd-MM-yyyy")).show()
    # +---+----------+-----+
    # | id|      date|state|
    # +---+----------+-----+
    # |  X|20-01-2023|    N|
    # |  X|21-01-2023|    S|
    # |  X|22-01-2023|   FS|
    # |  X|23-01-2023|    N|
    # |  X|24-01-2023|    E|
    # |  X|25-01-2023|   FE|
    # |  Y|20-01-2023|    S|
    # |  Y|23-01-2023|   FS|
    # +---+----------+-----+
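One caveat about ordering the window by the date column: if date is kept as a dd-MM-yyyy string, it is sorted lexicographically, which compares the day before the month and the year. The sample data falls within a single month, so it happens to sort correctly, but dates spanning several months would not. A small pure-Python illustration of the same effect (the dates list is made up for the demonstration):

```python
from datetime import datetime

dates = ["20-01-2023", "05-02-2023", "25-01-2023"]

# Lexicographic order compares characters left to right, i.e. day first.
as_strings = sorted(dates)

# Parsing with the dd-MM-yyyy pattern gives chronological order.
as_dates = sorted(dates, key=lambda d: datetime.strptime(d, "%d-%m-%Y"))

print(as_strings)  # ['05-02-2023', '20-01-2023', '25-01-2023']
print(as_dates)    # ['20-01-2023', '25-01-2023', '05-02-2023']
```

In PySpark the corresponding fix is to order the window by to_date(col("date"), "dd-MM-yyyy") rather than by the raw string column.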