How to use the salting technique for skewed aggregation in PySpark.
Say we have skewed data like the table below, where a couple of cities carry the vast majority of the rows. How do we create a salting column and use it in the aggregation? (A toy DataFrame with this shape is sketched right after the table.)
| city | state | count |
|---|---|---|
| Lachung | Sikkim | 3,000 |
| Rangpo | Sikkim | 50,000 |
| Gangtok | Sikkim | 300,000 |
| Bangalore | Karnataka | 25,000,000 |
| Mumbai | Maharashtra | 29,000,000 |
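To make the example runnable end to end, here is a minimal sketch that builds a toy `record_df` with this shape. The row counts are shrunk by a factor of 100 so it runs quickly; the name `record_df` matches the variable used in the code further down, and in practice the DataFrame would come from your own source.

```python
from pyspark.sql import SparkSession, functions as f

spark = SparkSession.builder.getOrCreate()

# Shrunken (city, state, row-count) triples mirroring the skewed table above.
rows = [
    ("Lachung", "Sikkim", 30),
    ("Rangpo", "Sikkim", 500),
    ("Gangtok", "Sikkim", 3_000),
    ("Bangalore", "Karnataka", 250_000),
    ("Mumbai", "Maharashtra", 290_000),
]

# Expand each triple into n individual records, one row per event.
record_df = (
    spark.createDataFrame(rows, ["city", "state", "n"])
    .withColumn("n", f.explode(f.sequence(f.lit(1), f.col("n"))))
    .drop("n")
)
```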
To apply the salting technique, add a column, say "salt", holding a random integer between 0 and (spark.sql.shuffle.partitions - 1).
The salted table looks like the one below, where the "salt" column ranges from 0 to 199 (spark.sql.shuffle.partitions defaults to 200). You can then aggregate in two stages: first groupBy("city", "state", "salt") to spread each hot key across many shuffle partitions, then groupBy("city", "state") to merge the partial results back into one row per key.
| city | state | salt |
|---|---|---|
| Lachung | Sikkim | 151 |
| Lachung | Sikkim | 102 |
| Lachung | Sikkim | 16 |
| Rangpo | Sikkim | 5 |
| Rangpo | Sikkim | 19 |
| Rangpo | Sikkim | 16 |
| Rangpo | Sikkim | 102 |
| Gangtok | Sikkim | 55 |
| Gangtok | Sikkim | 119 |
| Gangtok | Sikkim | 16 |
| Gangtok | Sikkim | 10 |
| Bangalore | Karnataka | 19 |
| Mumbai | Maharashtra | 0 |
| Bangalore | Karnataka | 199 |
| Mumbai | Maharashtra | 190 |
code:
from pyspark.sql import SparkSession, functions as f
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.getOrCreate()
num_partitions = int(spark.conf.get("spark.sql.shuffle.partitions"))

# Random salt in [0, num_partitions - 1]; floor keeps the distribution
# uniform, whereas round(rand() * n) under-weights the two end values.
salval = f.floor(f.rand() * num_partitions).cast(IntegerType())

result_df = (
    record_df.withColumn("salt", salval)
    # Stage 1: aggregate on the salted key, so each hot key is spread
    # across up to num_partitions shuffle partitions.
    .groupBy("city", "state", "salt")
    .agg(f.count("city").alias("partial_count"))
    # Stage 2: merge the partial counts back into one row per key.
    .groupBy("city", "state")
    .agg(f.sum("partial_count").alias("count"))
)
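The two aggregation passes are the essence of the trick: the first shuffle keys on ("city", "state", "salt"), so the rows for a hot city like Mumbai land on up to num_partitions different reducers instead of one, and the second, much smaller shuffle merges those partial counts. This works because count is additive; for a non-additive aggregate such as an average, carry a sum and a count through the first stage and divide in the second.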
output:
| city | state | count |
|---|---|---|
| Lachung | Sikkim | 3,000 |
| Rangpo | Sikkim | 50,000 |
| Gangtok | Sikkim | 300,000 |
| Bangalore | Karnataka | 25,000,000 |
| Mumbai | Maharashtra | 29,000,000 |
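To sanity-check the pipeline end to end (shown here against the toy `record_df` sketched earlier, so the counts come out shrunken rather than matching the figures in the table):

```python
# Largest cities first; salt no longer appears in the result.
result_df.orderBy(f.desc("count")).show(truncate=False)
```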