pyspark, apache-spark-sql, window-functions, spark3

Create a lookup column in pyspark


I am trying to create a new column in a PySpark dataframe that "looks up" the next event value in the same dataframe and fills it in on every row up to and including that event.

I have tried window functions as follows, but still no luck getting the next value into the column:

from pyspark.sql import Window
from pyspark.sql.functions import col, when, lit, lead

condition = (col("col2") == 'event_start_ind')
w = Window.partitionBy("col2").orderBy(*[when(condition, lit(1)).desc()])

df.select(["timestamp",
           "col1",
           "col2",
           "col3"
          ]).withColumn("col4", when(condition, lead("col3", 1).over(w))) \
.orderBy("timestamp") \
.show(500, truncate=False)

Apparently it won't look up the "next" event properly. Any ideas on possible approaches?

A sample dataframe would be:

timestamp                   col1  col2             col3
2021-02-02 01:03:55         s1    null             null
2021-02-02 01:04:16.952854  s1    other_ind        null
2021-02-02 01:04:32.398155  s1    null             null
2021-02-02 01:04:53.793089  s1    event_start_ind  event_1_value
2021-02-02 01:05:10.936913  s1    null             null
2021-02-02 01:05:36         s1    other_ind        null
2021-02-02 01:05:42         s1    null             null
2021-02-02 01:05:43         s1    null             null
2021-02-02 01:05:44         s1    event_start_ind  event_2_value
2021-02-02 01:05:46.623198  s1    null             null
2021-02-02 01:06:50         s1    null             null
2021-02-02 01:07:19.607685  s1    null             null

The desired result would be:

timestamp                   col1  col2             col3           col4
2021-02-02 01:03:55         s1    null             null           event_1_value
2021-02-02 01:04:16.952854  s1    other_ind        null           event_1_value
2021-02-02 01:04:32.398155  s1    null             null           event_1_value
2021-02-02 01:04:53.793089  s1    event_start_ind  event_1_value  event_1_value
2021-02-02 01:05:10.936913  s1    null             null           event_2_value
2021-02-02 01:05:36         s1    other_ind        null           event_2_value
2021-02-02 01:05:42         s1    null             null           event_2_value
2021-02-02 01:05:43         s1    null             null           event_2_value
2021-02-02 01:05:44         s1    event_start_ind  event_2_value  event_2_value
2021-02-02 01:05:46.623198  s1    null             null           null
2021-02-02 01:06:50         s1    null             null           null
2021-02-02 01:07:19.607685  s1    null             null           null

Solution

  • It looks like you don't have a partitioning key for your window, and the events do not contain the same number of records. Considering this, the solution that comes to my mind is to use the position of each event start to retrieve the respective value.

    With the rows sorted by timestamp, we first extract the position of each line:

    from pyspark.sql import Window
    from pyspark.sql.functions import col, rank, collect_list, expr
    
    df = (
      spark.createDataFrame(
        [
            { 'timestamp': '2021-02-02 01:03:55', 'col1': 's1' },
            { 'timestamp': '2021-02-02 01:04:16.952854', 'col1': 's1', 'col2': 'other_ind'},
            { 'timestamp': '2021-02-02 01:04:32.398155', 'col1': 's1'},
            { 'timestamp': '2021-02-02 01:04:53.793089', 'col1': 's1', 'col2': 'event_start_ind', 'col3': 'event_1_value'},
            { 'timestamp': '2021-02-02 01:05:10.936913', 'col1': 's1'},
            { 'timestamp': '2021-02-02 01:05:36', 'col1': 's1', 'col2': 'other_ind'},
            { 'timestamp': '2021-02-02 01:05:42', 'col1': 's1'},
            { 'timestamp': '2021-02-02 01:05:43', 'col1': 's1'},
            { 'timestamp': '2021-02-02 01:05:44', 'col1': 's1', 'col2': 'event_start_ind', 'col3': 'event_2_value'},
            { 'timestamp': '2021-02-02 01:05:46.623198', 'col1': 's1'},
            { 'timestamp': '2021-02-02 01:06:50', 'col1': 's1'},
            { 'timestamp': '2021-02-02 01:07:19.607685', 'col1': 's1'}
        ]
      )
      .withColumn('timestamp', col('timestamp').cast('timestamp'))
      .withColumn("line", rank().over(Window.orderBy("timestamp")))
    )
    
    df.show(truncate=False)
    
    +----+--------------------------+---------------+-------------+----+
    |col1|timestamp                 |col2           |col3         |line|
    +----+--------------------------+---------------+-------------+----+
    |s1  |2021-02-02 01:03:55       |null           |null         |1   |
    |s1  |2021-02-02 01:04:16.952854|other_ind      |null         |2   |
    |s1  |2021-02-02 01:04:32.398155|null           |null         |3   |
    |s1  |2021-02-02 01:04:53.793089|event_start_ind|event_1_value|4   |
    |s1  |2021-02-02 01:05:10.936913|null           |null         |5   |
    |s1  |2021-02-02 01:05:36       |other_ind      |null         |6   |
    |s1  |2021-02-02 01:05:42       |null           |null         |7   |
    |s1  |2021-02-02 01:05:43       |null           |null         |8   |
    |s1  |2021-02-02 01:05:44       |event_start_ind|event_2_value|9   |
    |s1  |2021-02-02 01:05:46.623198|null           |null         |10  |
    |s1  |2021-02-02 01:06:50       |null           |null         |11  |
    |s1  |2021-02-02 01:07:19.607685|null           |null         |12  |
    +----+--------------------------+---------------+-------------+----+
    

    After that, we identify each event start:

    df_event_start = (
        df.filter(col("col2") == 'event_start_ind')
        .select(
            col("line").alias("event_start_line"),
            col("col3").alias("event_value")
        )
    )
    df_event_start.show()
    
    +----------------+-------------+
    |event_start_line|  event_value|
    +----------------+-------------+
    |               4|event_1_value|
    |               9|event_2_value|
    +----------------+-------------+
    

    Then we use the event-start information to find, for each row, the next valid event start:

    df_with_event_starts = (
        df.join(
            df_event_start.select(collect_list('event_start_line').alias("event_starts"))
        )
        .withColumn("next_valid_event", expr("element_at(filter(event_starts, x -> x >= line), 1)"))
    )
    
    df_with_event_starts.show(truncate=False)
    
    +----+--------------------------+---------------+-------------+----+------------+----------------+
    |col1|timestamp                 |col2           |col3         |line|event_starts|next_valid_event|
    +----+--------------------------+---------------+-------------+----+------------+----------------+
    |s1  |2021-02-02 01:03:55       |null           |null         |1   |[4, 9]      |4               |
    |s1  |2021-02-02 01:04:16.952854|other_ind      |null         |2   |[4, 9]      |4               |
    |s1  |2021-02-02 01:04:32.398155|null           |null         |3   |[4, 9]      |4               |
    |s1  |2021-02-02 01:04:53.793089|event_start_ind|event_1_value|4   |[4, 9]      |4               |
    |s1  |2021-02-02 01:05:10.936913|null           |null         |5   |[4, 9]      |9               |
    |s1  |2021-02-02 01:05:36       |other_ind      |null         |6   |[4, 9]      |9               |
    |s1  |2021-02-02 01:05:42       |null           |null         |7   |[4, 9]      |9               |
    |s1  |2021-02-02 01:05:43       |null           |null         |8   |[4, 9]      |9               |
    |s1  |2021-02-02 01:05:44       |event_start_ind|event_2_value|9   |[4, 9]      |9               |
    |s1  |2021-02-02 01:05:46.623198|null           |null         |10  |[4, 9]      |null            |
    |s1  |2021-02-02 01:06:50       |null           |null         |11  |[4, 9]      |null            |
    |s1  |2021-02-02 01:07:19.607685|null           |null         |12  |[4, 9]      |null            |
    +----+--------------------------+---------------+-------------+----+------------+----------------+
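
    The expression element_at(filter(event_starts, x -> x >= line), 1) keeps only the event-start positions at or after the current line and returns the first of them. A standalone illustration of those two higher-order functions (the literal values are just for demonstration):

    spark.sql("SELECT element_at(filter(array(4, 9), x -> x >= 5), 1) AS next_valid_event").show()
    # filter() keeps [9] (the elements >= 5); element_at(..., 1) returns its first element, 9.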
    

    And finally we retrieve the correct value:

    (
        df_with_event_starts.join(
            df_event_start,
            col("next_valid_event") == col("event_start_line"),
            how="left"
        )
        .drop("line", "event_starts", "next_valid_event", "event_start_line")
        .show(truncate=False)
    )
    
    +----+--------------------------+---------------+-------------+-------------+
    |col1|timestamp                 |col2           |col3         |event_value  |
    +----+--------------------------+---------------+-------------+-------------+
    |s1  |2021-02-02 01:03:55       |null           |null         |event_1_value|
    |s1  |2021-02-02 01:04:16.952854|other_ind      |null         |event_1_value|
    |s1  |2021-02-02 01:04:32.398155|null           |null         |event_1_value|
    |s1  |2021-02-02 01:04:53.793089|event_start_ind|event_1_value|event_1_value|
    |s1  |2021-02-02 01:05:10.936913|null           |null         |event_2_value|
    |s1  |2021-02-02 01:05:36       |other_ind      |null         |event_2_value|
    |s1  |2021-02-02 01:05:42       |null           |null         |event_2_value|
    |s1  |2021-02-02 01:05:43       |null           |null         |event_2_value|
    |s1  |2021-02-02 01:05:44       |event_start_ind|event_2_value|event_2_value|
    |s1  |2021-02-02 01:05:46.623198|null           |null         |null         |
    |s1  |2021-02-02 01:06:50       |null           |null         |null         |
    |s1  |2021-02-02 01:07:19.607685|null           |null         |null         |
    +----+--------------------------+---------------+-------------+-------------+
    

    This solution can cause performance problems with large volumes of data, since every row is joined against the full list of event-start positions. If you can figure out a key for each event, I advise you to continue with your initial approach using window functions. In that case, you can try the last or first SQL functions (ignoring null values), as sketched below.
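
    A minimal sketch of that window-function alternative, assuming the rows can be partitioned by col1 and that col3 is only populated on event-start rows (the window name here is just for illustration):

    from pyspark.sql import Window
    from pyspark.sql.functions import first

    # Forward-looking frame: from the current row to the end of the partition.
    w_next = (
        Window.partitionBy("col1")
        .orderBy("timestamp")
        .rowsBetween(Window.currentRow, Window.unboundedFollowing)
    )

    # first(..., ignorenulls=True) picks the next non-null col3 at or after each row,
    # i.e. the value of the next event start.
    df.withColumn("col4", first("col3", ignorenulls=True).over(w_next)) \
      .orderBy("timestamp") \
      .show(truncate=False)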

    Hopefully, someone will help you with a better solution.

    Tip: Including the DataFrame creation script in the question is helpful.