pyspark, apache-spark-sql, pyspark-pandas

Index with groupby PySpark


I'm trying to translate the pandas code below to PySpark, but I'm having trouble with these two points:

  • Is there an index in a Spark DataFrame?
  • How can I group on level=0 like that?

I didn't find anything good in the documentation. If you have a hint, I'd be really grateful!

df.set_index('var1', inplace=True)
df['varGrouped'] = df.groupby(level=0)['var2'].min()
df.reset_index(inplace=True)

Solution

  • pandas_df.groupby(level=0) groups pandas_df by the first index field (in the case of multi-index data). Since the provided code sets only one index field, your code is simply a group by on the var1 field. Spark DataFrames don't have a row index, so the set_index/reset_index steps aren't needed; the same grouping can be replicated in PySpark with a groupBy() and taking the min of var2 (see the sketch at the end of this answer for that route).

    However, in your pandas code the aggregation result is stored in a new column within the same dataframe, so the number of rows doesn't decrease. That behaviour can be replicated with the min window function.

    import pyspark.sql.functions as func
    from pyspark.sql.window import Window as wd
    
    # minimum of var2 within each var1 partition, retained on every row
    data_sdf. \
        withColumn('grouped_var', func.min('var2').over(wd.partitionBy('var1')))
    

    withColumn helps you add/replace columns.


    Here's an example using sample data.

    data_sdf.show()
    
    # +---+---+
    # |  a|  b|
    # +---+---+
    # |  1|  2|
    # |  1|  3|
    # |  2|  5|
    # |  2|  4|
    # +---+---+
    
    data_sdf. \
        withColumn('grouped_res', func.min('b').over(wd.partitionBy('a'))). \
        show()
    
    # +---+---+-----------+
    # |  a|  b|grouped_res|
    # +---+---+-----------+
    # |  1|  2|          2|
    # |  1|  3|          2|
    # |  2|  5|          4|
    # |  2|  4|          4|
    # +---+---+-----------+
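
    For comparison, here's a minimal sketch of the groupBy() route mentioned at the start of this answer: aggregate to one row per group, then join the result back so the original row count is preserved. It assumes the same data_sdf sample and the func import from above.

    # aggregate to one row per group, then join back to keep every original row
    grouped = data_sdf. \
        groupBy('a'). \
        agg(func.min('b').alias('grouped_res'))
    
    data_sdf. \
        join(grouped, on='a', how='left'). \
        show()

    This gives the same grouped_res values as the window-function version (row order may differ); the window function just avoids the extra join.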