I am trying to create a new column in a PySpark dataframe that "looks up" the value of the next event in the same dataframe, so that each row carries the value of the upcoming event until that event occurs.
I have used window functions as follows, but still have no luck getting the next value into the column:
from pyspark.sql import Window
from pyspark.sql.functions import col, lead, lit, when

condition = (col("col2") == 'event_start_ind')
w = Window.partitionBy("col2").orderBy(*[when(condition, lit(1)).desc()])
df.select(["timestamp",
           "col1",
           "col2",
           "col3"
           ]).withColumn("col4", when(condition, lead("col3", 1).over(w))) \
    .orderBy("timestamp") \
    .show(500, truncate=False)
Apparently it won't look up the "next" event properly. Any ideas on possible approaches?
A sample dataframe would be:
timestamp | col1 | col2 | col3 |
---|---|---|---|
2021-02-02 01:03:55 | s1 | null | null |
2021-02-02 01:04:16.952854 | s1 | other_ind | null |
2021-02-02 01:04:32.398155 | s1 | null | null |
2021-02-02 01:04:53.793089 | s1 | event_start_ind | event_1_value |
2021-02-02 01:05:10.936913 | s1 | null | null |
2021-02-02 01:05:36 | s1 | other_ind | null |
2021-02-02 01:05:42 | s1 | null | null |
2021-02-02 01:05:43 | s1 | null | null |
2021-02-02 01:05:44 | s1 | event_start_ind | event_2_value |
2021-02-02 01:05:46.623198 | s1 | null | null |
2021-02-02 01:06:50 | s1 | null | null |
2021-02-02 01:07:19.607685 | s1 | null | null |
The desired result would be:
timestamp | col1 | col2 | col3 | col4 |
---|---|---|---|---|
2021-02-02 01:03:55 | s1 | null | null | event_1_value |
2021-02-02 01:04:16.952854 | s1 | other_ind | null | event_1_value |
2021-02-02 01:04:32.398155 | s1 | null | null | event_1_value |
2021-02-02 01:04:53.793089 | s1 | event_start_ind | event_1_value | event_1_value |
2021-02-02 01:05:10.936913 | s1 | null | null | event_2_value |
2021-02-02 01:05:36 | s1 | other_ind | null | event_2_value |
2021-02-02 01:05:42 | s1 | null | null | event_2_value |
2021-02-02 01:05:43 | s1 | null | null | event_2_value |
2021-02-02 01:05:44 | s1 | event_start_ind | event_2_value | event_2_value |
2021-02-02 01:05:46.623198 | s1 | null | null | null |
2021-02-02 01:06:50 | s1 | null | null | null |
2021-02-02 01:07:19.607685 | s1 | null | null | null |
It looks like you don't have a natural partition key for your window, and the events don't span the same number of rows, so a fixed lead/lag offset won't work. Given that, the solution that comes to mind is to use the position of each event start to retrieve the respective value.
Sorting by timestamp, we first assign a position to each line:
from pyspark.sql import Window
from pyspark.sql.functions import col, rank, collect_list, expr
df = (
spark.createDataFrame(
[
{ 'timestamp': '2021-02-02 01:03:55', 'col1': 's1' },
{ 'timestamp': '2021-02-02 01:04:16.952854', 'col1': 's1', 'col2': 'other_ind'},
{ 'timestamp': '2021-02-02 01:04:32.398155', 'col1': 's1'},
{ 'timestamp': '2021-02-02 01:04:53.793089', 'col1': 's1', 'col2': 'event_start_ind', 'col3': 'event_1_value'},
{ 'timestamp': '2021-02-02 01:05:10.936913', 'col1': 's1'},
{ 'timestamp': '2021-02-02 01:05:36', 'col1': 's1', 'col2': 'other_ind'},
{ 'timestamp': '2021-02-02 01:05:42', 'col1': 's1'},
{ 'timestamp': '2021-02-02 01:05:43', 'col1': 's1'},
{ 'timestamp': '2021-02-02 01:05:44', 'col1': 's1', 'col2': 'event_start_ind', 'col3': 'event_2_value'},
{ 'timestamp': '2021-02-02 01:05:46.623198', 'col1': 's1'},
{ 'timestamp': '2021-02-02 01:06:50', 'col1': 's1'},
{ 'timestamp': '2021-02-02 01:07:19.607685', 'col1': 's1'}
]
)
.withColumn('timestamp', col('timestamp').cast('timestamp'))
# number the rows in timestamp order (rank() assumes no duplicate timestamps here)
.withColumn("line", rank().over(Window.orderBy("timestamp")))
)
df.show(truncate=False)
+----+--------------------------+---------------+-------------+----+
|col1|timestamp |col2 |col3 |line|
+----+--------------------------+---------------+-------------+----+
|s1 |2021-02-02 01:03:55 |null |null |1 |
|s1 |2021-02-02 01:04:16.952854|other_ind |null |2 |
|s1 |2021-02-02 01:04:32.398155|null |null |3 |
|s1 |2021-02-02 01:04:53.793089|event_start_ind|event_1_value|4 |
|s1 |2021-02-02 01:05:10.936913|null |null |5 |
|s1 |2021-02-02 01:05:36 |other_ind |null |6 |
|s1 |2021-02-02 01:05:42 |null |null |7 |
|s1 |2021-02-02 01:05:43 |null |null |8 |
|s1 |2021-02-02 01:05:44 |event_start_ind|event_2_value|9 |
|s1 |2021-02-02 01:05:46.623198|null |null |10 |
|s1 |2021-02-02 01:06:50 |null |null |11 |
|s1 |2021-02-02 01:07:19.607685|null |null |12 |
+----+--------------------------+---------------+-------------+----+
After that we identify each event start:
df_event_start = (
df.filter(col("col2") == 'event_start_ind')
.select(
col("line").alias("event_start_line"),
col("col3").alias("event_value")
)
)
df_event_start.show()
+----------------+-------------+
|event_start_line| event_value|
+----------------+-------------+
| 4|event_1_value|
| 9|event_2_value|
+----------------+-------------+
Then we use the event start information to find, for each row, the next valid event start:
df_with_event_starts = (
    df.join(
        # cross join: attach the list of all event-start positions to every row
        df_event_start.select(collect_list('event_start_line').alias("event_starts"))
    )
    # keep only the positions at or after the current line and take the first one,
    # i.e. the position of the next valid event start
    .withColumn("next_valid_event", expr("element_at(filter(event_starts, x -> x >= line), 1)"))
)
df_with_event_starts.show(truncate=False)
+----+--------------------------+---------------+-------------+----+------------+----------------+
|col1|timestamp |col2 |col3 |line|event_starts|next_valid_event|
+----+--------------------------+---------------+-------------+----+------------+----------------+
|s1 |2021-02-02 01:03:55 |null |null |1 |[4, 9] |4 |
|s1 |2021-02-02 01:04:16.952854|other_ind |null |2 |[4, 9] |4 |
|s1 |2021-02-02 01:04:32.398155|null |null |3 |[4, 9] |4 |
|s1 |2021-02-02 01:04:53.793089|event_start_ind|event_1_value|4 |[4, 9] |4 |
|s1 |2021-02-02 01:05:10.936913|null |null |5 |[4, 9] |9 |
|s1 |2021-02-02 01:05:36 |other_ind |null |6 |[4, 9] |9 |
|s1 |2021-02-02 01:05:42 |null |null |7 |[4, 9] |9 |
|s1 |2021-02-02 01:05:43 |null |null |8 |[4, 9] |9 |
|s1 |2021-02-02 01:05:44 |event_start_ind|event_2_value|9 |[4, 9] |9 |
|s1 |2021-02-02 01:05:46.623198|null |null |10 |[4, 9] |null |
|s1 |2021-02-02 01:06:50 |null |null |11 |[4, 9] |null |
|s1 |2021-02-02 01:07:19.607685|null |null |12 |[4, 9] |null |
+----+--------------------------+---------------+-------------+----+------------+----------------+
Finally, we join back on the event-start position to retrieve the correct value:
(
df_with_event_starts.join(
df_event_start,
col("next_valid_event") == col("event_start_line"),
how="left"
)
.drop("line", "event_starts", "next_valid_event", "event_start_line")
.show(truncate=False)
)
+----+--------------------------+---------------+-------------+-------------+
|col1|timestamp |col2 |col3 |event_value |
+----+--------------------------+---------------+-------------+-------------+
|s1 |2021-02-02 01:03:55 |null |null |event_1_value|
|s1 |2021-02-02 01:04:16.952854|other_ind |null |event_1_value|
|s1 |2021-02-02 01:04:32.398155|null |null |event_1_value|
|s1 |2021-02-02 01:04:53.793089|event_start_ind|event_1_value|event_1_value|
|s1 |2021-02-02 01:05:10.936913|null |null |event_2_value|
|s1 |2021-02-02 01:05:36 |other_ind |null |event_2_value|
|s1 |2021-02-02 01:05:42 |null |null |event_2_value|
|s1 |2021-02-02 01:05:43 |null |null |event_2_value|
|s1 |2021-02-02 01:05:44 |event_start_ind|event_2_value|event_2_value|
|s1 |2021-02-02 01:05:46.623198|null |null |null |
|s1 |2021-02-02 01:06:50 |null |null |null |
|s1 |2021-02-02 01:07:19.607685|null |null |null |
+----+--------------------------+---------------+-------------+-------------+
This solution will cause performance problems on large volumes of data: the window has no partition, and the cross join replicates the event-start list onto every row.
If you can figure out a key for each event, I advise you to continue with your initial approach using window functions; in that case you can try the last or first SQL functions with ignorenulls=True.
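As an illustration of that idea, here is a minimal sketch against the sample dataframe above. It does a backward fill with first(..., ignorenulls=True) over a window that looks from the current row forward; partitioning by col1 is an assumption that each col1 value is an independent series (if that does not hold, drop the partitionBy and accept the single-partition cost mentioned above).

from pyspark.sql import Window
from pyspark.sql.functions import first

# window that looks from the current row to the end of the (ordered) partition
w = (
    Window.partitionBy("col1")
    .orderBy("timestamp")
    .rowsBetween(Window.currentRow, Window.unboundedFollowing)
)

# for each row, take the first non-null col3 at or after it,
# i.e. the value of the next upcoming event_start_ind row
(
    df.withColumn("col4", first("col3", ignorenulls=True).over(w))
      .orderBy("timestamp")
      .show(truncate=False)
)

On the sample data this should reproduce the desired col4, including the trailing nulls after the last event.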
Hopefully, someone will help you with a better solution.
Tip: including a dataframe creation script in the question makes it much easier for others to help.