python, pyspark, azure-databricks

PySpark cumulative lag logic


In PySpark on Databricks, I'm trying to recalculate column A in a dataframe using the following recursive formula:

A1 = A0 + A0 * B0 / 100    
A2 = A1 + A1 * B1 / 100

...

Initial Table

Column A  Column B
3740      -15
3740      -5
3740      -10

(unlimited depth)

Result

Column A  Column B
3740      -15
3179      -5
3020.05   -10

(unlimited depth)
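
For reference, the expected values above follow directly from applying the formula row by row; a quick check in plain Python with the sample data:

```python
# Apply A_{n+1} = A_n + A_n * B_n / 100 to the sample data from the tables above
A, values = 3740.0, []
B = [-15.0, -5.0, -10.0]   # sample Column B
for b in B:
    values.append(A)
    A = A + A * b / 100

print(values)  # [3740.0, 3179.0, 3020.05] up to floating-point rounding
```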

Maybe someone has an idea?

I tried the lag function, but it doesn't support recursive calculation, and I can't find a way to work around that. Row-by-row processing takes an endless amount of time because of the data volume.


Solution

  • Since your column B is variable, collect it into a Python list and use either of the functions below, calling it as a UDF.

    def calculate_nth_term(A0, n):
        # Iteratively apply n growth factors: A_n = A_0 * prod_{i<n} (1 + B_i/100)
        An = float(A0)
        for i in range(n):
            An *= (1 + B[i] / 100)
        return An
    

    or

    def calculate_nth_term(A0, n):
        # Recursive form of the same product: A_n = A_{n-1} * (1 + B_{n-1}/100)
        if n == 0:
            return float(A0)
        else:
            An = calculate_nth_term(A0, n - 1) * (1 + B[n - 1] / 100)
            return An
    

    Here, B is the list taken from the dataframe.
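
    As a quick sanity check outside Spark, with B hard-coded to the question's sample values:

    ```python
    # B hard-coded from the question's sample data, for illustration only;
    # in the real job it is collected from the dataframe as shown below
    B = [-15.0, -5.0, -10.0]

    def calculate_nth_term(A0, n):
        # A_n = A_0 * prod_{i<n} (1 + B_i/100)
        An = float(A0)
        for i in range(n):
            An *= (1 + B[i] / 100)
        return An

    for n in range(3):
        print(calculate_nth_term(3740, n))  # 3740.0, 3179.0, 3020.05 (up to rounding)
    ```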

    code:

    from pyspark.sql.types import FloatType
    from pyspark.sql.functions import col, udf, row_number, lit
    from pyspark.sql.window import Window

    # Collect column B as a Python list for the UDF to close over
    B = df.select("B").rdd.map(lambda x: x.B).collect()

    calculate_An_udf = udf(calculate_nth_term, FloatType())

    # Note: ordering by "A" only gives a stable result here because A is constant;
    # in general, order by a column that defines the intended row order
    windowSpec = Window().orderBy("A")
    tmp_df = df.withColumn("n", row_number().over(windowSpec) - lit(1))
    tmp_df.withColumn("An", calculate_An_udf(col("A"), col("n"))).show()
    

    This computes the results recursively.

    output:

    A     B    n  An
    3740  -15  0  3740
    3740  -5   1  3179
    3740  -10  2  3020.05
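
    Since the recursion unrolls to a cumulative product, A_n = A_0 * prod_{i<n} (1 + B_i/100), you can also avoid the collect() and the Python UDF entirely by taking exp of a windowed sum of logs. A sketch, assuming an explicit ordering column "id" (hypothetical; the question's data has no such column, but one is needed for a deterministic row order):

    ```python
    from pyspark.sql import SparkSession, Window
    from pyspark.sql import functions as F

    spark = SparkSession.builder.getOrCreate()

    # "id" is an assumed ordering column added for this sketch
    df = spark.createDataFrame(
        [(1, 3740.0, -15.0), (2, 3740.0, -5.0), (3, 3740.0, -10.0)],
        ["id", "A", "B"],
    )

    # Sum log-growth-factors over all *previous* rows, then exponentiate:
    # A_n = A_0 * exp(sum_{i<n} log(1 + B_i/100))
    w = Window.orderBy("id").rowsBetween(Window.unboundedPreceding, -1)
    result = df.withColumn(
        "An",
        F.col("A")
        * F.exp(F.coalesce(F.sum(F.log(1 + F.col("B") / 100)).over(w), F.lit(0.0))),
    )
    result.show()
    ```

    This stays fully distributed (no driver-side list), though it assumes every 1 + B/100 is positive so the log is defined.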