
PySpark: using a window to create a field from a previously created field's value


I am trying to create a new column in my df called indexCP using Window. I want to take the previous indexCP value * (current_df['return'] + 1); if there is no previous indexCP, use 100 * (current_df['return'] + 1) instead.

column_list = ["id", "secname"]
windowval = (
    Window.partitionBy(column_list)
      .orderBy(col('calendarday').cast("timestamp").cast("long"))
      .rangeBetween(Window.unboundedPreceding, 0)
)
spark_df = spark_df.withColumn(
    'indexCP',
    when(
        spark_df["PreviousYearUnique"] == spark_df["yearUnique"],
        100 * (current_df['return'] + 1)
    ).otherwise(last('indexCP').over(windowval) * (current_df['return'] + 1))
)

When I run the above code I get the error "AnalysisException: cannot resolve 'indexCP' given input columns", which I believe means you can't reference a column that has not been created yet, but I am unsure how to fix it.

Starting Data Frame
## +---+-----------+----------+------------------+       
## | id|calendarday|   secName|            return|
## +---+-----------+----------+------------------+
## |  1|2015-01-01 |         1|            0.0076|
## |  1|2015-01-02 |         1|            0.0026|
## |  1|2015-01-01 |         2|            0.0016|
## |  1|2015-01-02 |         2|            0.0006|
## |  2|2015-01-01 |         3|            0.0012|
## |  2|2015-01-02 |         3|            0.0014|
## +---+-----------+----------+------------------+

New Data Frame IndexCP added
## +---+-----------+--------+---------+------------+       
## | id|calendarday| secName|   return|     IndexCP|
## +---+-----------+--------+---------+------------+
## |  1|2015-01-01 |       1|   0.0076|      100.76|(1st 100*(return+1))
## |  1|2015-01-02 |       1|   0.0026|  101.021976|(2nd 100.76*(return+1))
## |  2|2015-01-01 |       2|   0.0016|      100.16|(1st 100*(return+1))
## |  2|2015-01-02 |       2|   0.0006|  100.220096|(2nd 100.16*(return+1))
## |  3|2015-01-01 |       3|   0.0012|     100.12 |(1st 100*(return+1))
## |  3|2015-01-02 |       3|   0.0014|  100.260168|(2nd 100.12*(return+1))
## +---+-----------+--------+---------+------------+

Solution

  • EDIT: This should be the final answer; I've extended the example by another row per secName.

    What you're looking for is a rolling product using your formula IndexCP * (current_return + 1). First collect all preceding returns into an ArrayType column over a window, then fold them with Spark SQL's aggregate higher-order function:

    from pyspark.sql import functions as f
    from pyspark.sql.window import Window

    column_list = ["id", "secname"]
    windowval = (
        Window.partitionBy(column_list)
          .orderBy(f.col('calendarday').cast("timestamp"))
          .rangeBetween(Window.unboundedPreceding, 0)
    )
    
    
    df1.show()
    +---+-----------+-------+------+
    | id|calendarday|secName|return|
    +---+-----------+-------+------+
    |  1| 2015-01-01|      1|0.0076|
    |  1| 2015-01-02|      1|0.0026|
    |  1| 2015-01-03|      1|0.0014|
    |  2| 2015-01-01|      2|0.0016|
    |  2| 2015-01-02|      2|6.0E-4|
    |  2| 2015-01-03|      2|   0.0|
    |  3| 2015-01-01|      3|0.0012|
    |  3| 2015-01-02|      3|0.0014|
    +---+-----------+-------+------+
    
    # f.collect_list(...) gets all your returns - this must be windowed
    # cast(1 as double) is your base of 1 to begin with
    # (acc, x) -> acc * (1 + x) is your formula translated to Spark SQL
    # where acc is the accumulated value and x is the incoming value
    df1.withColumn(
        "rolling_returns", 
        f.collect_list("return").over(windowval)
    ).withColumn("IndexCP", 
        100 * f.expr("""
        aggregate(
           rolling_returns,
           cast(1 as double),
           (acc, x) -> acc * (1+x))
        """)
    ).orderBy("id", "calendarday").show(truncate=False)
    
    +---+-----------+-------+------+------------------------+------------------+
    |id |calendarday|secName|return|rolling_returns         |IndexCP           |
    +---+-----------+-------+------+------------------------+------------------+
    |1  |2015-01-01 |1      |0.0076|[0.0076]                |100.76            |
    |1  |2015-01-02 |1      |0.0026|[0.0076, 0.0026]        |101.021976        |
    |1  |2015-01-03 |1      |0.0014|[0.0076, 0.0026, 0.0014]|101.16340676640002|
    |2  |2015-01-01 |2      |0.0016|[0.0016]                |100.16000000000001|
    |2  |2015-01-02 |2      |6.0E-4|[0.0016, 6.0E-4]        |100.220096        |
    |2  |2015-01-03 |2      |0.0   |[0.0016, 6.0E-4, 0.0]   |100.220096        |
    |3  |2015-01-01 |3      |0.0012|[0.0012]                |100.12            |
    |3  |2015-01-02 |3      |0.0014|[0.0012, 0.0014]        |100.26016800000002|
    +---+-----------+-------+------+------------------------+------------------+
    

    Explanation: the starting value of the fold must be 1, and the multiplier of 100 must sit outside the expression; otherwise the 100 is compounded on every row, and the result drifts above the expected returns by a factor of 100 per row.
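    The drift is easy to reproduce with a plain-Python fold. This sketch uses functools.reduce as a stand-in for the Spark SQL aggregate expression and contrasts the correct placement of the 100 with compounding it on every row:

```python
from functools import reduce

returns = [0.0076, 0.0026]  # secName == 1, first two rows

# Correct: fold with a base of 1, scale by 100 once at the end.
good = 100 * reduce(lambda acc, x: acc * (1 + x), returns, 1.0)

# Wrong: baking the 100 into every step compounds it once per row,
# so each extra row inflates the result by another factor of 100.
bad = reduce(lambda acc, x: acc * 100 * (1 + x), returns, 1.0)

print(good)  # close to 101.021976, as in the output table
print(bad)   # two rows, so roughly 100x too large
```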

    I have verified the values now adhere to your formula, for instance for secName == 1 and id == 1:

    100 * ((1 + 0.0076) * (1 + 0.0026) * (1 + 0.0014)) = 101.1634067664
    

    Which is indeed correct according to the formula (acc, x) -> acc * (1+x). Hope this helps!
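    The same check can be scripted outside Spark. This sketch mirrors the aggregate(..., (acc, x) -> acc * (1+x)) expression with functools.reduce; the helper name rolling_index is mine, not part of the original answer:

```python
from functools import reduce

def rolling_index(returns, base=100.0):
    """Mirror of the Spark SQL expression: fold the returns with
    acc * (1 + x) starting from 1, then scale by the base."""
    return base * reduce(lambda acc, x: acc * (1 + x), returns, 1.0)

# secName == 1, id == 1 rows from the output table above
print(rolling_index([0.0076, 0.0026, 0.0014]))
```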