
Rolling Sum on Spark Dataframe filtered by other column


I have a Spark dataframe that looks like this: it contains a single row for every combination of articlenumber, countrycode and date for which a value in amount exists. There are about 400,000 rows in this dataframe.

articlenumber   countrycode   date         amount
--------------------------------------------------
4421-222-222    DE            2020-02-05   200
1234-567-890    EN            2019-05-23   42
1345-457-456    EN            2019-12-12   107

Now I need an additional column "amount 12M" whose value is calculated for each row according to the following rule:

In every row, "amount 12M" should contain the sum of all values from 'amount' whose articlenumber and countrycode match those of that specific row and whose date lies within the 12 months before the date in that row.

Do I need to add rows with the amount 0 for date/country/articlenumber combinations that haven't got a value yet?

As I'm not quite an expert in programming (I'm an engineering student), I need some help with how to achieve this within the Python script that deals with this dataframe.

Thank you for any ideas on this.
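
For reference, a minimal sketch that builds the sample dataframe above (assuming spark is an available SparkSession and that date is kept as a yyyy-MM-dd string, as the answers below expect):

    from pyspark.sql import SparkSession

    # Minimal reproduction of the sample data; in the real setup the dataframe
    # comes from elsewhere and has about 400,000 rows.
    spark = SparkSession.builder.getOrCreate()

    df = spark.createDataFrame(
        [
            ("4421-222-222", "DE", "2020-02-05", 200),
            ("1234-567-890", "EN", "2019-05-23", 42),
            ("1345-457-456", "EN", "2019-12-12", 107),
        ],
        ["articlenumber", "countrycode", "date", "amount"],
    )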


Solution

  • Edited: a calendar-month version that orders the window by a month index, so rangeBetween(-11, 0) covers the current month and the 11 months before it:

    import pyspark.sql.functions as f
    from pyspark.sql import Window
    
    # One window per articlenumber/countrycode, ordered by the month index;
    # rangeBetween(-11, 0) spans the current month and the 11 preceding months.
    w = Window.partitionBy('articlenumber', 'countrycode').orderBy('yearmonth').rangeBetween(-11, 0)
    
    df.withColumn('yearmonth', f.expr('(year(date) - 2000) * 12 + month(date)')) \
      .withColumn('amount 12M', f.sum('amount').over(w)) \
      .orderBy('date').show(10, False)
    
    +-------------+-----------+----------+------+---------+----------+
    |articlenumber|countrycode|date      |amount|yearmonth|amount 12M|
    +-------------+-----------+----------+------+---------+----------+
    |4421-222-222 |DE         |2019-02-05|100   |230      |100       |
    |4421-222-222 |DE         |2019-03-01|50    |231      |150       |
    |1234-567-890 |EN         |2019-05-23|42    |233      |42        |
    |1345-457-456 |EN         |2019-12-12|107   |240      |107       |
    |4421-222-222 |DE         |2020-02-05|200   |242      |250       |
    +-------------+-----------+----------+------+---------+----------+
    
    

    I am not sure which definition of "12 months" you need: the version above counts whole calendar months (the current month plus the 11 before it), while the version below uses a fixed 365-day range. That is why 2019-02-05 is excluded from the first result for 2020-02-05 (250) but included in the second (350).

    import pyspark.sql.functions as f
    from pyspark.sql import Window
    
    # One window per articlenumber/countrycode, spanning the 365 days (expressed
    # in seconds) up to and including the current row's date.
    w = Window.partitionBy('articlenumber', 'countrycode').orderBy('unix_date').rangeBetween(-365 * 86400, 0)
    
    df.withColumn('unix_date', f.unix_timestamp('date', 'yyyy-MM-dd')) \
      .withColumn('amount 12M', f.sum('amount').over(w)) \
      .orderBy('date').show(10, False)
    
    +-------------+-----------+----------+------+----------+----------+
    |articlenumber|countrycode|date      |amount|unix_date |amount 12M|
    +-------------+-----------+----------+------+----------+----------+
    |4421-222-222 |DE         |2019-02-05|100   |1549324800|100       |
    |4421-222-222 |DE         |2019-02-06|50    |1549411200|150       |
    |1234-567-890 |EN         |2019-05-23|42    |1558569600|42        |
    |1345-457-456 |EN         |2019-12-12|107   |1576108800|107       |
    |4421-222-222 |DE         |2020-02-05|200   |1580860800|350       |
    +-------------+-----------+----------+------+----------+----------+
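
    If the sum has to cover exactly 12 calendar months rather than a fixed 365 days, one alternative is a conditional self-join with add_months. This is only a sketch (exact_12m is an illustrative name) and is usually more expensive than a window function on large data:

    import pyspark.sql.functions as f

    a = df.alias('a')
    b = df.alias('b')

    # For every row of a, sum all rows of b with the same articlenumber and
    # countrycode whose date lies in the half-open interval (date - 12 months, date].
    exact_12m = a.join(
        b,
        (f.col('a.articlenumber') == f.col('b.articlenumber'))
        & (f.col('a.countrycode') == f.col('b.countrycode'))
        & (f.to_date('b.date') > f.add_months(f.to_date('a.date'), -12))
        & (f.to_date('b.date') <= f.to_date('a.date'))
    ).groupBy('a.articlenumber', 'a.countrycode', 'a.date', 'a.amount') \
     .agg(f.sum('b.amount').alias('amount 12M'))

    And no, you do not need to add rows with amount 0 for missing date/country/articlenumber combinations: both the range windows and the join only look at the rows that exist.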