
Issues with mean and groupby using pyspark


I have an issue using mean together with groupBy (or a window) on a PySpark dataframe. The code below returns a dataframe called mean_df in which the column mean_change1 is NaN for all three dates.

I don't understand why this is the case. Is it due to the presence of NaN values in df?

import pandas as pd
import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
import pyspark.sql.functions as F

np.random.seed(10)

data = pd.DataFrame({'ID':list(range(1,101)) + list(range(1,101)) + list(range(1,101)),
                     'date':[pd.to_datetime('2021-07-30')]*100 + [pd.to_datetime('2022-12-31')] * 100 + [pd.to_datetime('2023-04-30')]*100,
                     'value1': [i for i in np.random.normal(98, 4.5, 100)] + [np.nan] *3 + [i for i in np.random.normal(100, 5, 97)] + [i for i in np.random.normal(120, 8, 95)] +[5.3, 9000, 160, -5222, 158],
                     'value2':[i for i in np.random.normal(52, 11, 100)] + [i for i in np.random.normal(50, 10, 100)] + [i for i in np.random.normal(49, 10, 100)]
             })

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(data)

# Calculate change

window = Window.partitionBy("ID").orderBy("date")
df = df.withColumn("previous_value1", F.lag("value1", 1).over(window))
df = df.withColumn("change1", df["value1"] - df["previous_value1"])

mean_df = df.groupBy("date").agg(F.mean("change1").alias("mean_change1"))

mean_df.toPandas()


Solution

  • You could replace the NaN values with nulls and then calculate the mean. Spark's aggregate functions (sum, count, mean) skip nulls, but they treat NaN as an ordinary double value, so a single NaN in a group makes the whole mean NaN.

    window = Window.partitionBy("ID").orderBy("date")

    # replace NaN with null *before* computing the lag, so the
    # change1 column contains nulls (skipped by mean) instead of NaNs
    data1_sdf = df. \
        replace(np.nan, None). \
        withColumn("previous_value1", F.lag("value1", 1).over(window)). \
        withColumn("change1", F.col("value1") - F.col("previous_value1"))
    
    data1_sdf. \
        groupBy("date"). \
        agg(F.sum("change1").alias("sum_change1"), 
            F.count("change1").alias("cnt_change1"), 
            F.mean("change1").alias("mean_change1")
            ). \
        show()
    
    # +-------------------+------------------+-----------+------------------+
    # |               date|       sum_change1|cnt_change1|      mean_change1|
    # +-------------------+------------------+-----------+------------------+
    # |2023-04-30 00:00:00| 5355.116077984317|         97|55.207382247260995|
    # |2021-07-30 00:00:00|              null|          0|              null|
    # |2022-12-31 00:00:00|164.43277052556064|         97|1.6951832012944397|
    # +-------------------+------------------+-----------+------------------+
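The underlying behaviour — NaN propagating through arithmetic while null-like missing values are simply skipped — can be sketched in plain Python, independent of Spark:

```python
import math

# NaN propagates through arithmetic: one NaN poisons the whole sum,
# which is why the grouped mean came back as NaN
vals = [1.0, 2.0, float('nan')]
print(sum(vals))                           # nan
print(math.isnan(sum(vals) / len(vals)))   # True

# A null-like missing value, by contrast, is excluded from both the
# sum and the count -- this is what Spark's mean() does with nulls
clean = [v for v in vals if not math.isnan(v)]
print(sum(clean) / len(clean))             # 1.5
```

This is why replacing NaN with None before aggregating restores a meaningful mean: the nulls drop out of both the numerator and the denominator.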