I'm trying to translate the pandas code below to PySpark, but I'm having trouble with these two points:

# group by var1, take the per-group min of var2, and keep it on every row
df.set_index('var1', inplace=True)
df['varGrouped'] = df.groupby(level=0)['var2'].min()
df.reset_index(inplace=True)

I didn't find anything good in the documentation. If you have a hint, I'll be really grateful!
pandas_df.groupby(level=0) would group pandas_df by the first index field (in the case of multi-index data). Since there is only one index field based on the provided code, your code is simply a group by on the var1 field. The same can be replicated in PySpark with a groupBy() and taking the min of var2.
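For contrast, a plain groupBy() aggregation collapses the result to one row per var1 value, which is not what your pandas snippet does. A minimal sketch, assuming your Spark dataframe is called data_sdf:

import pyspark.sql.functions as func

# one row per distinct var1, with the per-group minimum of var2
data_sdf.groupBy('var1'). \
    agg(func.min('var2').alias('varGrouped')). \
    show()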
However, in your pandas code the aggregation result is stored in a new column within the same dataframe, so the number of rows doesn't decrease. This can be replicated by using min as a window function.
import pyspark.sql.functions as func
from pyspark.sql.window import Window as wd

# minimum of var2 within each var1 partition, kept on every row
data_sdf. \
    withColumn('grouped_var', func.min('var2').over(wd.partitionBy('var1')))
withColumn helps you add or replace columns.
Here's an example using sample data.
data_sdf.show()
# +---+---+
# | a| b|
# +---+---+
# | 1| 2|
# | 1| 3|
# | 2| 5|
# | 2| 4|
# +---+---+
data_sdf. \
withColumn('grouped_res', func.min('b').over(wd.partitionBy('a'))). \
show()
# +---+---+-----------+
# | a| b|grouped_res|
# +---+---+-----------+
# | 1| 2| 2|
# | 1| 3| 2|
# | 2| 5| 4|
# | 2| 4| 4|
# +---+---+-----------+
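So, for your original column names, the full translation of the pandas snippet would be something along these lines (a sketch, assuming your Spark dataframe is called data_sdf with columns var1 and var2):

import pyspark.sql.functions as func
from pyspark.sql.window import Window as wd

# equivalent of set_index('var1') / groupby(level=0)['var2'].min() / reset_index()
data_sdf = data_sdf. \
    withColumn('varGrouped', func.min('var2').over(wd.partitionBy('var1')))

No reset_index step is needed, since Spark dataframes don't carry an index.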