I have the code below, and I want to rewrite the Pandas UDF as pure window functions in PySpark to speed it up.
The column cumulative_pass is what I want to create programmatically:
import pandas as pd
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql import Window
from pyspark.sql.types import IntegerType  # return type of the pandas UDF further down
import sys
spark_session = SparkSession.builder.getOrCreate()
df_data = {'username': ['bob','bob', 'bob', 'bob', 'bob', 'bob', 'bob', 'bob'],
'session': [1,2,3,4,5,6,7,8],
'year_start': [2020,2020,2020,2020,2020,2021,2022,2023],
'year_end': [2020,2020,2020,2020,2021,2021,2022,2023],
'pass': [1,0,0,0,0,1,1,0],
'cumulative_pass': [0,0,0,0,0,1,2,3],
}
df_pandas = pd.DataFrame.from_dict(df_data)
df = spark_session.createDataFrame(df_pandas)
df.show()
The last show() call prints this:
+--------+-------+----------+--------+----+---------------+
|username|session|year_start|year_end|pass|cumulative_pass|
+--------+-------+----------+--------+----+---------------+
| bob| 1| 2020| 2020| 1| 0|
| bob| 2| 2020| 2020| 0| 0|
| bob| 3| 2020| 2020| 0| 0|
| bob| 4| 2020| 2020| 0| 0|
| bob| 5| 2020| 2021| 0| 0|
| bob| 6| 2021| 2021| 1| 1|
| bob| 7| 2022| 2022| 1| 2|
| bob| 8| 2023| 2023| 0| 3|
+--------+-------+----------+--------+----+---------------+
The code below works but is slow (UDFs are slow):
def conditional_sum(data: pd.DataFrame) -> int:
    df = data.apply(pd.Series)
    return df.loc[df['year_start'].max() > df['year_end']]['pass'].sum()
udf_conditional_sum = F.pandas_udf(conditional_sum, IntegerType())
w = Window.partitionBy("username").orderBy(F.asc("year_start")).rowsBetween(-sys.maxsize, 0)
df = df.withColumn("calculate_cumulative_pass", udf_conditional_sum(F.struct("year_start", "year_end", "pass")).over(w))
Note: I modified w slightly and removed a second sort.
W = Window.partitionBy('username').orderBy('year_start')
df = (
df
.withColumn('cumulative_pass', F.collect_list(F.struct('year_end', 'pass')).over(W))
.withColumn('cumulative_pass', F.expr("AGGREGATE(cumulative_pass, 0, (acc, x) -> CAST(acc + IF(x['year_end'] < year_start, x['pass'], 0) AS INT))"))
)
Create a window specification and collect the pairs of year_end and pass values for all rows up to the current one in the window. Then aggregate the pairs, summing the pass values whose year_end is less than the year_start of the current row.
+--------+-------+----------+--------+----+---------------+
|username|session|year_start|year_end|pass|cumulative_pass|
+--------+-------+----------+--------+----+---------------+
|bob |1 |2020 |2020 |1 |0 |
|bob |2 |2020 |2020 |0 |0 |
|bob |3 |2020 |2020 |0 |0 |
|bob |4 |2020 |2020 |0 |0 |
|bob |5 |2020 |2021 |0 |0 |
|bob |6 |2021 |2021 |1 |1 |
|bob |7 |2022 |2022 |1 |2 |
|bob |8 |2023 |2023 |0 |3 |
+--------+-------+----------+--------+----+---------------+
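As a sanity check on the logic, here is a minimal plain-Python sketch of the same conditional sum, run on the sample rows from the question. It is not a replacement for the Spark code. The function name cumulative_pass and the tuple layout are made up for illustration, and the sketch assumes every row has year_start <= year_end, so any row that satisfies the filter already lies inside the window.

```python
from collections import defaultdict

# Sample rows copied from the question:
# (username, session, year_start, year_end, pass)
rows = [
    ("bob", 1, 2020, 2020, 1),
    ("bob", 2, 2020, 2020, 0),
    ("bob", 3, 2020, 2020, 0),
    ("bob", 4, 2020, 2020, 0),
    ("bob", 5, 2020, 2021, 0),
    ("bob", 6, 2021, 2021, 1),
    ("bob", 7, 2022, 2022, 1),
    ("bob", 8, 2023, 2023, 0),
]


def cumulative_pass(rows):
    """Hypothetical helper: for each row, sum `pass` over the same user's
    rows whose year_end is strictly less than this row's year_start.

    Assumes year_start <= year_end on every row (an assumption, not
    something stated in the question).
    """
    # Group (year_end, pass) pairs per user, like partitionBy('username').
    pairs_by_user = defaultdict(list)
    for username, _session, _year_start, year_end, passed in rows:
        pairs_by_user[username].append((year_end, passed))

    result = []
    for username, _session, year_start, _year_end, _passed in rows:
        # Same filter as the AGGREGATE lambda: x['year_end'] < year_start.
        total = sum(p for ye, p in pairs_by_user[username] if ye < year_start)
        result.append(total)
    return result


print(cumulative_pass(rows))  # [0, 0, 0, 0, 0, 1, 2, 3]
```

The printed list matches the cumulative_pass column in the table above, which suggests the filter x['year_end'] < year_start captures the intended rule for this sample.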