I am trying to get the sum of Revenue over the last 3 Month rows (excluding the current row) for each Client. Minimal example with current attempt in Databricks:
import pandas as pd
import numpy as np

cols = ['Client','Month','Revenue']
df_pd = pd.DataFrame([['A',201701,100],
['A',201702,101],
['A',201703,102],
['A',201704,103],
['A',201705,104],
['B',201701,201],
['B',201702,np.nan],
['B',201703,203],
['B',201704,204],
['B',201705,205],
['B',201706,206],
['B',201707,207]
])
df_pd.columns = cols
spark_df = spark.createDataFrame(df_pd)
spark_df.createOrReplaceTempView('df_sql')
df_out = sqlContext.sql("""
select *, (sum(ifnull(Revenue,0)) over (partition by Client
order by Client,Month
rows between 3 preceding and 1 preceding)) as Total_Sum3
from df_sql
""")
df_out.show()
+------+------+-------+----------+
|Client| Month|Revenue|Total_Sum3|
+------+------+-------+----------+
| A|201701| 100.0| null|
| A|201702| 101.0| 100.0|
| A|201703| 102.0| 201.0|
| A|201704| 103.0| 303.0|
| A|201705| 104.0| 306.0|
| B|201701| 201.0| null|
| B|201702| NaN| 201.0|
| B|201703| 203.0| NaN|
| B|201704| 204.0| NaN|
| B|201705| 205.0| NaN|
| B|201706| 206.0| 612.0|
| B|201707| 207.0| 615.0|
+------+------+-------+----------+
As you can see, if a NaN exists anywhere in the 3-month window, the window sum returns NaN rather than a number. I would like to treat those values as 0, hence the ifnull attempt, but this does not seem to work. I have also tried a CASE statement to change NULL to 0, with no luck.
It is Apache Spark, my bad! (I am working in Databricks and thought it was MySQL under the hood.) Is it too late to change the title?
@Barmar, you are right in that IFNULL() doesn't treat NaN as null.
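Spark SQL treats NaN as an ordinary (non-null) double value, so ifnull/coalesce pass it straight through. A quick way to see this (a minimal sketch; assumes the usual spark session available in Databricks):

# NaN survives ifnull because it is not a SQL NULL:
spark.sql("""
    select ifnull(cast('NaN' as double), 0) as v,
           isnull(cast('NaN' as double)) as v_is_null
""").show()
# v comes back as NaN and v_is_null as false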
I managed to figure out the fix thanks to @user6910411 from here: SO link. I had to change the numpy NaNs to Spark nulls. The corrected code, picking up right after the sample df_pd is created:
spark_df = spark.createDataFrame(df_pd)
from pyspark.sql.functions import isnan, col, when

# Convert NaNs in float/double columns to null; when() without an
# otherwise() clause yields null wherever the condition is false.
spark_df = spark_df.select([
    when(~isnan(c), col(c)).alias(c) if t in ("double", "float") else c
    for c, t in spark_df.dtypes
])
spark_df.createOrReplaceTempView('df_sql')
df_out = sqlContext.sql("""
select *, (sum(ifnull(Revenue,0)) over (partition by Client
order by Client,Month
rows between 3 preceding and 1 preceding)) as Total_Sum3
from df_sql order by Client,Month
""")
df_out.show()
which then gives the desired output:
+------+------+-------+----------+
|Client| Month|Revenue|Total_Sum3|
+------+------+-------+----------+
| A|201701| 100.0| null|
| A|201702| 101.0| 100.0|
| A|201703| 102.0| 201.0|
| A|201704| 103.0| 303.0|
| A|201705| 104.0| 306.0|
| B|201701| 201.0| null|
| B|201702| null| 201.0|
| B|201703| 203.0| 201.0|
| B|201704| 204.0| 404.0|
| B|201705| 205.0| 407.0|
| B|201706| 206.0| 612.0|
| B|201707| 207.0| 615.0|
+------+------+-------+----------+
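As an aside, the NaN-to-null swap could probably also be done on the pandas side before the frame is handed to Spark (a sketch I have not verified against this exact data): widening to object dtype keeps None from being coerced back to NaN.

# Hypothetical alternative: replace NaN with None in pandas first
df_pd = df_pd.astype(object).where(pd.notnull(df_pd), None)
spark_df = spark.createDataFrame(df_pd)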
Is sqlContext the best way to approach this or would it be better / more elegant to achieve the same result via pyspark.sql.window?
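For comparison, the pyspark.sql Window equivalent would look something like this (a sketch with the same logic, not a definitive answer on which is more elegant):

from pyspark.sql import Window
from pyspark.sql import functions as F

# Same frame as the SQL version: per Client, ordered by Month,
# rows between 3 preceding and 1 preceding.
w = Window.partitionBy('Client').orderBy('Month').rowsBetween(-3, -1)

df_out = (spark_df
          .withColumn('Total_Sum3',
                      F.sum(F.coalesce(F.col('Revenue'), F.lit(0))).over(w))
          .orderBy('Client', 'Month'))
df_out.show()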