Tags: python, apache-spark, pyspark, group-by, grouping

PySpark equivalent to a groupby on categories in pandas?


In pandas, we can group by a categorical series and, when aggregating, every category is displayed, regardless of whether it contains any records.

import numpy as np
import pandas as pd

df = pd.DataFrame({"Age": [12, 20, 40, 60, 72]}, dtype=np.float64)
cuts = pd.cut(df.Age, bins=[0, 11, 30, 60])
df.Age.groupby(cuts).agg(mean="mean", occurrences="size")

#           mean  occurrences
# Age                        
# (0, 11]    NaN            0
# (11, 30]  16.0            2
# (30, 60]  50.0            2

As you can see, the first bin is displayed even though no record falls into it. How can I achieve the same behaviour in PySpark?


Solution

  • The following is quite a lot of code, but I'm not aware of any nicer method.

    from pyspark.sql import functions as F
    df = spark.createDataFrame([(12,), (20,), (40,), (60,), (72,)], ['Age'])
    
    bins = [0, 11, 30, 60]
    
    # Chained when() clauses mapping an integer id to its bin label; starting
    # from the functions module makes the first .when() call F.when().
    conds = F
    for i, b in enumerate(bins[1:]):
        conds = conds.when(F.col('id') <= b, f'({bins[i]}, {b}]')

    # Lookup table: one row per integer in the binned range, labelled with its bin.
    df2 = spark.range(1, bins[-1]+1).withColumn('_grp', conds)

    # Left join keeps every bin, even those no Age falls into; ages outside the
    # bins (e.g. 72) are dropped, as with pd.cut.
    df = df2.join(df, df2.id == df.Age, 'left')
    df = df.groupBy(F.col('_grp').alias('Age')).agg(
        F.mean('Age').alias('mean'),
        F.count('Age').alias('occurrences'),
    )
    
    df.show()
    # +--------+----+-----------+
    # |     Age|mean|occurrences|
    # +--------+----+-----------+
    # |(11, 30]|16.0|          2|
    # | (0, 11]|null|          0|
    # |(30, 60]|50.0|          2|
    # +--------+----+-----------+
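
  • An alternative sketch that avoids materialising one row per integer: label each Age with its bin directly, then left-join the full list of bin labels so that empty bins still appear. This is only a rough illustration assuming the same spark session and bins as above; the names Age_bin, grp, all_bins and result are made up for the example.

    from pyspark.sql import functions as F

    df = spark.createDataFrame([(12,), (20,), (40,), (60,), (72,)], ['Age'])
    bins = [0, 11, 30, 60]
    labels = [f'({lo}, {hi}]' for lo, hi in zip(bins, bins[1:])]

    # Chain when() clauses; ages above the last bound (e.g. 72) fall through
    # to null and are dropped, as with pd.cut.
    grp = F
    for (lo, hi), label in zip(zip(bins, bins[1:]), labels):
        grp = grp.when((F.col('Age') > lo) & (F.col('Age') <= hi), label)

    agg = (df.withColumn('Age_bin', grp)
             .where(F.col('Age_bin').isNotNull())
             .groupBy('Age_bin')
             .agg(F.mean('Age').alias('mean'), F.count('Age').alias('occurrences')))

    # Left join from the complete label list keeps bins with no rows;
    # their count becomes 0 and their mean stays null.
    all_bins = spark.createDataFrame([(l,) for l in labels], ['Age_bin'])
    result = (all_bins.join(agg, 'Age_bin', 'left')
                      .withColumn('occurrences', F.coalesce(F.col('occurrences'), F.lit(0))))
    result.show()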