I have a Spark DataFrame like the one below. It contains a single row for every combination of articlenumber, countrycode and date for which a value in amount exists. There are about 400,000 rows in this DataFrame.
articlenumber  countrycode  date        amount
----------------------------------------------
4421-222-222   DE           2020-02-05  200
1234-567-890   EN           2019-05-23  42
1345-457-456   EN           2019-12-12  107
Now I need an additional column "amount 12M" that calculates its value for each row according to the following rules:
In every row, "amount 12M" should contain the sum of all values of amount where articlenumber and countrycode match those of that row and where the date falls within the 12 months up to and including the date of that row.
Do I need to add rows with the amount 0 for date/country/articlenumber combinations that haven't got a value yet?
As I'm not really an expert in programming (I'm an engineering student), I need some help with how to achieve this in the Python script that processes this DataFrame.
Thank you for any ideas on this.
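To pin the rule down, here is a small plain-Python reference (not Spark) that computes "amount 12M" for a handful of sample rows. It assumes "12 months" means the current calendar month plus the 11 months before it, which is the interpretation the yearmonth answer uses. The sample rows are taken from the output tables in the answers, and `month_index` and `amount_12m` are hypothetical helper names.

```python
from collections import defaultdict
from datetime import date

# Sample rows (articlenumber, countrycode, date, amount), taken from the answers' output
rows = [
    ("4421-222-222", "DE", date(2019, 2, 5), 100),
    ("4421-222-222", "DE", date(2019, 3, 1), 50),
    ("1234-567-890", "EN", date(2019, 5, 23), 42),
    ("1345-457-456", "EN", date(2019, 12, 12), 107),
    ("4421-222-222", "DE", date(2020, 2, 5), 200),
]

def month_index(d):
    """Consecutive integer per calendar month, so subtracting gives a month count."""
    return d.year * 12 + d.month

def amount_12m(rows):
    """For each row, sum the amounts with the same article and country whose
    month lies in the 12 calendar months ending with the row's own month."""
    groups = defaultdict(list)
    for article, country, d, amount in rows:
        groups[(article, country)].append((month_index(d), amount))
    result = []
    for article, country, d, amount in rows:
        m = month_index(d)
        total = sum(a for mi, a in groups[(article, country)] if m - 11 <= mi <= m)
        result.append((article, country, d, amount, total))
    return result

for row in amount_12m(rows):
    print(row)
```

Because each total only looks at rows that actually exist, adding rows with amount 0 would not change any result, so no padding is needed. The same holds in Spark when the window uses rangeBetween, whose bounds are compared against the ordering value; rowsBetween counts physical rows and would only match a 12-month window if there were exactly one row per month.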
Edited:
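import pyspark.sql.functions as f
from pyspark.sql import Window
# One window per article/country, ordered by a month index. Calling orderBy twice
# keeps only the last ordering, so order by 'yearmonth' alone; rangeBetween(-11, 0)
# then covers the current month plus the 11 months before it.
w = Window.partitionBy('articlenumber', 'countrycode').orderBy('yearmonth').rangeBetween(-11, 0)
# yearmonth numbers months consecutively (2019-02 -> 230, 2020-02 -> 242)
df.withColumn('yearmonth', f.expr('(year(date) - 2000) * 12 + month(date)')) \
    .withColumn('amount 12M', f.sum('amount').over(w)) \
    .orderBy('date').show(10, False)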
+-------------+-----------+----------+------+---------+----------+
|articlenumber|countrycode|date |amount|yearmonth|amount 12M|
+-------------+-----------+----------+------+---------+----------+
|4421-222-222 |DE |2019-02-05|100 |230 |100 |
|4421-222-222 |DE |2019-03-01|50 |231 |150 |
|1234-567-890 |EN |2019-05-23|42 |233 |42 |
|1345-457-456 |EN |2019-12-12|107 |240 |107 |
|4421-222-222 |DE |2020-02-05|200 |242 |250 |
+-------------+-----------+----------+------+---------+----------+
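The yearmonth trick works because the expression maps every calendar month to a consecutive integer, so rangeBetween(-11, 0) is a plain numeric range. A plain-Python mirror of the same arithmetic (the `yearmonth` helper is hypothetical) reproduces the yearmonth column above and the frame bounds for the last row:

```python
from datetime import date

def yearmonth(d):
    # Same arithmetic as the Spark expression (year(date) - 2000) * 12 + month(date)
    return (d.year - 2000) * 12 + d.month

dates = [date(2019, 2, 5), date(2019, 3, 1), date(2019, 5, 23),
         date(2019, 12, 12), date(2020, 2, 5)]
print([yearmonth(d) for d in dates])  # [230, 231, 233, 240, 242]

# rangeBetween(-11, 0) on yearmonth 242 accepts values 231..242,
# i.e. March 2019 through February 2020
current = yearmonth(date(2020, 2, 5))
lower, upper = current - 11, current
print(lower, upper)  # 231 242
```

That is why the 2019-02-05 row (230) is left out of the 2020-02-05 total of 250.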
I'm not sure how exactly you define 12 months, but if a rolling window of 365 days is acceptable, this will also work:
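import pyspark.sql.functions as f
from pyspark.sql import Window
# unix_date is in seconds, so the frame spans the 365 days up to and including the row's date
w = Window.partitionBy('articlenumber', 'countrycode').orderBy('unix_date').rangeBetween(-365 * 86400, 0)
# Convert the date to seconds since the epoch (interpreted in the session time zone)
df.withColumn('unix_date', f.unix_timestamp('date', 'yyyy-MM-dd')) \
    .withColumn('amount 12M', f.sum('amount').over(w)) \
    .orderBy('date').show(10, False)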
+-------------+-----------+----------+------+----------+----------+
|articlenumber|countrycode|date |amount|unix_date |amount 12M|
+-------------+-----------+----------+------+----------+----------+
|4421-222-222 |DE |2019-02-05|100 |1549324800|100 |
|4421-222-222 |DE |2019-02-06|50 |1549411200|150 |
|1234-567-890 |EN |2019-05-23|42 |1558569600|42 |
|1345-457-456 |EN |2019-12-12|107 |1576108800|107 |
|4421-222-222 |DE |2020-02-05|200 |1580860800|350 |
+-------------+-----------+----------+------+----------+----------+
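The 365-day variant is not quite the same as "12 months" once a leap day falls inside the window. This plain-Python sketch reproduces the unix_date values above and shows the difference; it assumes a UTC session time zone (which is what those unix_date values correspond to), and `unix_seconds` is a hypothetical helper.

```python
from datetime import date, datetime, timedelta, timezone

SECONDS_PER_DAY = 86400
WINDOW = 365 * SECONDS_PER_DAY  # same bound as rangeBetween(-365 * 86400, 0)

def unix_seconds(d):
    # Midnight UTC; Spark's unix_timestamp uses the session time zone instead
    return int(datetime(d.year, d.month, d.day, tzinfo=timezone.utc).timestamp())

print(unix_seconds(date(2019, 2, 5)))  # 1549324800
# 2019-02-05 is exactly 365 days before 2020-02-05, so it sits on the frame boundary
print(unix_seconds(date(2020, 2, 5)) - unix_seconds(date(2019, 2, 5)) == WINDOW)  # True

# Across a leap day, 365 days back does not land on the same calendar date
print(date(2020, 3, 1) - timedelta(days=365))  # 2019-03-02
```

So for a row dated 2020-03-01, a row dated 2019-03-01 (366 days earlier) falls outside the 365-day window, while 2019-02-05 above is still included for 2020-02-05. Whether that matters depends on how strictly "12 months" has to be interpreted.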