In PySpark on Databricks, I'm trying to recalculate column A in a dataframe using the following recursive formula:
A1 = A0 + A0 * B0 / 100
A2 = A1 + A1 * B1 / 100
...
Initial Table
Column A | Column B
---|---
3740 | -15
3740 | -5
3740 | -10
(unlimited depth)
Result
Column A | Column B
---|---
3740 | -15
3179 | -5
3020.05 | -10
(unlimited depth)
Maybe someone has an idea?
I tried the lag function, but it doesn't support recursive calculation and I can't find a way around that. Row-by-row processing takes an endless amount of time because of the data volume.
Since your column B is variable, collect it into a list and use either of the functions below, called as a UDF. The recursion unrolls into a cumulative product, An = A0 * (1 + B[0]/100) * (1 + B[1]/100) * ... * (1 + B[n-1]/100), so each row only needs its A value and its position n:
def calculate_nth_term(A0, n):
    # multiply A0 by (1 + B[i]/100) for each of the first n rows
    An = float(A0)
    for i in range(n):
        An *= (1 + B[i] / 100)
    return An
or
def calculate_nth_term(A0, n):
    # recursive version: the nth term is built from the (n-1)th term
    if n == 0:
        return float(A0)
    else:
        An = calculate_nth_term(A0, n - 1) * (1 + B[n - 1] / 100)
        return An
Here, B is the list of column B values taken from the dataframe.
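For the sample data above, B would be [-15, -5, -10]. As a quick hand-computed sanity check of either function (the commented values follow from the formula, not from captured output):

B = [-15, -5, -10]
calculate_nth_term(3740, 0)  # 3740.0
calculate_nth_term(3740, 1)  # 3740 * 0.85 = 3179.0
calculate_nth_term(3740, 2)  # 3179 * 0.95 = 3020.0499...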
code:

from pyspark.sql.types import FloatType
from pyspark.sql.functions import col, udf, row_number, lit
from pyspark.sql.window import Window

# collect column B to the driver as a plain Python list
B = df.select("B").rdd.map(lambda x: x.B).collect()

calculate_An_udf = udf(calculate_nth_term, FloatType())

# assign each row its 0-based position n; the orderBy must produce
# the same row order that was used when collecting B
windowSpec = Window().orderBy("A")
tmp_df = df.withColumn("n", row_number().over(windowSpec) - lit(1))
tmp_df.withColumn("An", calculate_An_udf(col("A"), col("n"))).show()
This calculates the results recursively.
output:
A | B | n | An
---|---|---|---
3740 | -15 | 0 | 3740
3740 | -5 | 1 | 3179
3740 | -10 | 2 | 3020.05
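Note that the UDF recomputes the product from scratch for every row, and collecting B pulls the whole column onto the driver. If that becomes a bottleneck, the same cumulative product can be expressed directly with window functions as the exponential of a running sum of logarithms. A minimal sketch, assuming a column ord that defines the row order (hypothetical here, since the sample data has no ordering column) and that 1 + B/100 is always positive so the log is defined:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# frame over all preceding rows; empty for the first row
w = Window.orderBy("ord").rowsBetween(Window.unboundedPreceding, -1)

# An = A * product of (1 + B/100) over preceding rows, computed as
# exp(sum(log(...))); coalesce covers the empty frame on the first row.
# No partitionBy, so Spark will pull everything into a single task.
result = df.withColumn(
    "An",
    F.col("A") * F.coalesce(
        F.exp(F.sum(F.log(1 + F.col("B") / 100)).over(w)),
        F.lit(1.0),
    ),
)
result.show()

This avoids both the driver-side collect and the Python UDF.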