Tags: python, dataframe, apache-spark, pyspark

pyspark - perform a cumulative sum over a partition based on a conditional statement


Example code below. The cumulative_pass column is what I want to create programmatically:

import pandas as pd
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql import Window

spark_session = SparkSession.builder.getOrCreate()

df_data = {'username': ['bob','bob', 'bob', 'bob', 'bob', 'bob', 'bob', 'bob'],
           'session': [1,2,3,4,5,6,7,8],
           'year_start': [2020,2020,2020,2020,2020,2021,2022,2023],
           'year_end': [2020,2020,2020,2020,2021,2021,2022,2023],
           'pass': [1,0,0,0,0,1,1,0],
           'cumulative_pass': [0,0,0,0,0,1,2,3],
          }
df_pandas = pd.DataFrame.from_dict(df_data)


df = spark_session.createDataFrame(df_pandas)
df.show()

The last show() prints this:

+--------+-------+----------+--------+----+---------------+
|username|session|year_start|year_end|pass|cumulative_pass|
+--------+-------+----------+--------+----+---------------+
|     bob|      1|      2020|    2020|   1|              0|
|     bob|      2|      2020|    2020|   0|              0|
|     bob|      3|      2020|    2020|   0|              0|
|     bob|      4|      2020|    2020|   0|              0|
|     bob|      5|      2020|    2021|   0|              0|
|     bob|      6|      2021|    2021|   1|              1|
|     bob|      7|      2022|    2022|   1|              2|
|     bob|      8|      2023|    2023|   0|              3|
+--------+-------+----------+--------+----+---------------+

The cumulative_pass column should sum the pass column over all prior rows whose year_end is strictly less than the current row's year_start.
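A quick way to sanity-check that rule against the expected values: in pandas, each row's cumulative_pass is the sum of pass over same-user rows with year_end below that row's year_start (the "prior rows" restriction is implied here, since the year ranges never decrease across sessions). A minimal sketch using the df_pandas defined above; the "check" column is just a hypothetical helper name:

# For each row, sum `pass` over same-user rows with year_end < this row's year_start
df_pandas["check"] = df_pandas.apply(
    lambda r: df_pandas.loc[
        (df_pandas["username"] == r["username"])
        & (df_pandas["year_end"] < r["year_start"]), "pass"
    ].sum(),
    axis=1,
)
assert (df_pandas["check"] == df_pandas["cumulative_pass"]).all()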

My attempt (not working, due to syntax errors):

from pyspark.sql.types import IntegerType  # this import was missing

def conditional_sum(data: pd.Series) -> int:
    df = data.apply(pd.Series)  # expand the struct column into separate columns
    return df.loc[df['year_start'].max() > df['year_end']]['pass'].sum()

# .over(w) requires a grouped-aggregate pandas UDF; the default SCALAR variant fails here
udf_conditional_sum = F.pandas_udf(conditional_sum, IntegerType(), F.PandasUDFType.GROUPED_AGG)

# Remaining logic flaw: the frame ends at -1, so the UDF never sees the current row;
# df['year_start'].max() is the max over *prior* rows, not the current row's
# year_start. This is why the self join in the Solution below is used instead.
w = (Window.partitionBy("username")
     .orderBy(F.asc("year_start"), F.desc("year_end"))
     .rowsBetween(Window.unboundedPreceding, -1))
df = df.withColumn("calculate_cumulative_pass",
                   udf_conditional_sum(F.struct("year_start", "year_end", "pass")).over(w))

Code based on https://stackoverflow.com/a/73278159/5004050


Solution

  • Perform a self join on the dataframe so that year_start in the left dataframe is greater than year_end in the right dataframe, matching on username so the sum stays within each user's partition. Then group the result by the left dataframe's columns and aggregate pass with SUM to get the desired cumulative sum. (A DataFrame-API version of the same join is sketched after the output below.)

    df.createOrReplaceTempView('df')
    df1 = spark_session.sql(
    """
    SELECT
        A.username, A.session,
        A.year_start, A.year_end, A.pass,
        COALESCE(SUM(B.pass), 0) AS cumulative_sum
    FROM
        df AS A
    LEFT JOIN
        df AS B
    ON
        A.username = B.username
        AND A.year_start > B.year_end
    GROUP BY
        A.username, A.session,
        A.year_start, A.year_end, A.pass
    """)
    

    df1.show()
    +--------+-------+----------+--------+----+---------------+
    |username|session|year_start|year_end|pass| cumulative_sum|
    +--------+-------+----------+--------+----+---------------+
    |     bob|      1|      2020|    2020|   1|              0|
    |     bob|      2|      2020|    2020|   0|              0|
    |     bob|      3|      2020|    2020|   0|              0|
    |     bob|      4|      2020|    2020|   0|              0|
    |     bob|      5|      2020|    2021|   0|              0|
    |     bob|      6|      2021|    2021|   1|              1|
    |     bob|      7|      2022|    2022|   1|              2|
    |     bob|      8|      2023|    2023|   0|              3|
    +--------+-------+----------+--------+----+---------------+
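
  • For reference, here is a DataFrame-API version of the same self join (a minimal sketch against the df defined in the question; the alias names a and b are only illustrative):

    a, b = df.alias("a"), df.alias("b")
    df2 = (
        a.join(
            b,
            (F.col("a.username") == F.col("b.username"))
            & (F.col("a.year_start") > F.col("b.year_end")),
            "left",
        )
        .groupBy("a.username", "a.session", "a.year_start", "a.year_end", "a.pass")
        .agg(F.coalesce(F.sum("b.pass"), F.lit(0)).alias("cumulative_sum"))
        .orderBy("session")
    )
    df2.show()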