I have a PySpark DataFrame with 3 columns: 'client', 'product', 'date'. I want to run a groupBy operation:
df.groupBy("product", "date").agg(F.countDistinct("client"))
So I want to count, for each product and each day, the number of distinct clients that bought that product. This causes heavy data skew (in fact, the job fails with an out-of-memory error). I have been learning about salting techniques. As I understand it, salting can be used with 'sum' or 'count' by adding a salt column to the groupBy and performing a second aggregation, but I do not see how to apply it in this case because of the countDistinct aggregation.
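For a plain count, my understanding of the pattern is something like the sketch below (the "salt" column and NUM_SALTS are just placeholder names I picked, not from any reference):

import pyspark.sql.functions as F

NUM_SALTS = 16

# spread each ("product", "date") group across NUM_SALTS buckets
salted = df.withColumn("salt", (F.rand() * NUM_SALTS).cast("int"))

# first aggregation: count within each salted sub-group
partial = (salted
    .groupBy("product", "date", "salt")
    .agg(F.count("client").alias("partial_count"))
)

# second aggregation: add the partial counts back together
counts = (partial
    .groupBy("product", "date")
    .agg(F.sum("partial_count").alias("purchases"))
)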
How can I apply it in this case?
I would recommend not using countDistinct
at all here and instead achieving what you want with two aggregations in a row, especially since your data is skewed. It can look like the following:
import pyspark.sql.functions as F
new_df = (df
    .groupBy("product", "date", "client")
    .agg({})  # one row per distinct ("product", "date", "client") tuple
    .groupBy("product", "date")
    .agg(F.count('*').alias('clients'))
)
The first aggregation ensures that you end up with one row per distinct ("product", "date", "client") tuple; the second counts the number of clients for each ("product", "date") pair. This way you don't need to worry about skew anymore, since Spark can do partial (map-side) aggregations for you, as opposed to countDistinct,
which is forced to send all of the individual "client" values for each ("product", "date") pair to a single node.
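If you want to convince yourself that the two-step version gives the same numbers as countDistinct, you can compare them on a tiny DataFrame (this assumes an active SparkSession named spark; the sample values are made up):

import pyspark.sql.functions as F

sample = spark.createDataFrame(
    [("p1", "2021-01-01", "c1"),
     ("p1", "2021-01-01", "c1"),  # same client buys the same product twice
     ("p1", "2021-01-01", "c2"),
     ("p2", "2021-01-01", "c1")],
    ["product", "date", "client"])

two_step = (sample
    .groupBy("product", "date", "client")
    .agg({})
    .groupBy("product", "date")
    .agg(F.count('*').alias('clients')))

direct = (sample
    .groupBy("product", "date")
    .agg(F.countDistinct("client").alias("clients")))

two_step.show()
direct.show()  # both should report 2 clients for p1 and 1 client for p2

You can also call new_df.explain() on your real DataFrame to inspect the physical plan and see the partial aggregation steps before each shuffle.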