How can I use salting to perform a cumulative sum window operation? The sample below is tiny, but in my real data the id column is heavily skewed, and I effectively need to perform this operation on it:
window_unsalted = Window.partitionBy("id").orderBy("timestamp")
# expected value
df = df.withColumn("Expected", F.sum('value').over(window_unsalted))
However, at the scale of my data I cannot compute it this way, so I want to try salting.
Consider this MWE. How can I replicate the expected value, 20, using salting techniques?
from pyspark.sql import functions as F
from pyspark.sql.window import Window
data = [
(7329, 1636617182, 1.0),
(7329, 1636142065, 1.0),
(7329, 1636142003, 1.0),
(7329, 1680400388, 1.0),
(7329, 1636142400, 1.0),
(7329, 1636397030, 1.0),
(7329, 1636142926, 1.0),
(7329, 1635970969, 1.0),
(7329, 1636122419, 1.0),
(7329, 1636142195, 1.0),
(7329, 1636142654, 1.0),
(7329, 1636142484, 1.0),
(7329, 1636119628, 1.0),
(7329, 1636404275, 1.0),
(7329, 1680827925, 1.0),
(7329, 1636413478, 1.0),
(7329, 1636143578, 1.0),
(7329, 1636413800, 1.0),
(7329, 1636124556, 1.0),
(7329, 1636143614, 1.0),
(7329, 1636617778, -1.0),
(7329, 1636142155, -1.0),
(7329, 1636142061, -1.0),
(7329, 1680400415, -1.0),
(7329, 1636142480, -1.0),
(7329, 1636400183, -1.0),
(7329, 1636143444, -1.0),
(7329, 1635977251, -1.0),
(7329, 1636122624, -1.0),
(7329, 1636142298, -1.0),
(7329, 1636142720, -1.0),
(7329, 1636142584, -1.0),
(7329, 1636122147, -1.0),
(7329, 1636413382, -1.0),
(7329, 1680827958, -1.0),
(7329, 1636413538, -1.0),
(7329, 1636143610, -1.0),
(7329, 1636414011, -1.0),
(7329, 1636141936, -1.0),
(7329, 1636146843, -1.0)
]
df = spark.createDataFrame(data, ["id", "timestamp", "value"])
# Define the number of salt buckets
num_buckets = 100
# Add a salted_id column to the dataframe
df = df.withColumn("salted_id", (F.concat(F.col("id"),
(F.rand(seed=42)*num_buckets).cast("int")).cast("string")))
# Define a window partitioned by the salted_id, and ordered by timestamp
window = Window.partitionBy("salted_id").orderBy("timestamp")
# Add a cumulative sum column
df = df.withColumn("cumulative_sum", F.sum("value").over(window))
# Define a window partitioned by the original id, and ordered by timestamp
window_unsalted = Window.partitionBy("id").orderBy("timestamp")
# Compute the final cumulative sum by adding up the cumulative sums within each original id
df = df.withColumn("final_cumulative_sum",
F.sum("cumulative_sum").over(window_unsalted))
# expected value
df = df.withColumn("Expected", F.sum('value').over(window_unsalted))
# incorrect attempt
df.agg(F.sum('final_cumulative_sum')).show()
# expected value
df.agg(F.sum('Expected')).show()
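As a quick check of where 20 comes from, here is a minimal plain-Python sketch (no Spark involved) of the unsalted running sum on the MWE rows. Sorted by timestamp, the +1.0 and -1.0 rows happen to alternate, so the running sum alternates between 1.0 and 0.0 and its total over the 40 rows is 20.0.

```python
# Plain-Python reference for the "Expected" column (no Spark involved).
# Timestamps of id 7329 from the MWE, split by value.
plus_ts = [
    1636617182, 1636142065, 1636142003, 1680400388, 1636142400,
    1636397030, 1636142926, 1635970969, 1636122419, 1636142195,
    1636142654, 1636142484, 1636119628, 1636404275, 1680827925,
    1636413478, 1636143578, 1636413800, 1636124556, 1636143614,
]
minus_ts = [
    1636617778, 1636142155, 1636142061, 1680400415, 1636142480,
    1636400183, 1636143444, 1635977251, 1636122624, 1636142298,
    1636142720, 1636142584, 1636122147, 1636413382, 1680827958,
    1636413538, 1636143610, 1636414011, 1636141936, 1636146843,
]
rows = sorted([(ts, 1.0) for ts in plus_ts] + [(ts, -1.0) for ts in minus_ts])

# Running sum ordered by timestamp, like F.sum("value").over(window_unsalted)
running_sums = []
running = 0.0
for _, value in rows:
    running += value
    running_sums.append(running)

print(running_sums[:6])   # [1.0, 0.0, 1.0, 0.0, 1.0, 0.0]
print(sum(running_sums))  # 20.0
```

Matching the total of 20 is a useful smoke test, but on its own it does not prove that every row's value is right.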
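From what I see, the main issue is that the timestamps must stay in order for the partial cumulative sums to be correct: e.g., if the sequence is 1, 2, 3, then 2 cannot go into a different partition than 1 and 3. With random salting, consecutive timestamps end up in different buckets, so there is no way to stitch the per-bucket cumulative sums back together.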
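The 1, 2, 3 example can be made concrete with a small plain-Python sketch. The helper offsets_needed is hypothetical, purely for illustration: for each bucket it records what would have to be added to the bucket-local running sum to recover the true running sum. A fix-up step that adds one number per bucket (for example via a join on the salt) can only work when every bucket needs exactly one offset.

```python
def offsets_needed(values, buckets):
    """values and buckets are parallel lists, already sorted by timestamp.
    Returns {bucket: set of offsets needed to turn the bucket-local
    running sum into the true running sum}."""
    true_running = 0
    local_running = {}
    offsets = {}
    for v, b in zip(values, buckets):
        true_running += v
        local_running[b] = local_running.get(b, 0) + v
        offsets.setdefault(b, set()).add(true_running - local_running[b])
    return offsets


values = [1, 2, 3]  # rows at timestamps 1, 2, 3

# Random-style salting: timestamp 2 lands in a different bucket than 1 and 3,
# so bucket A would need two different offsets
print(offsets_needed(values, ["A", "B", "A"]))  # {'A': {0, 2}, 'B': {1}}

# Order-preserving salting: each bucket holds a contiguous time range,
# so one offset per bucket is enough
print(offsets_needed(values, ["A", "A", "B"]))  # {'A': {0}, 'B': {3}}
```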
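My suggestion is to use a salt value based on the timestamp, so that the ordering is preserved. The idea is to compute cumulative sums within each (id, salt) bucket and then add the total of all earlier buckets of the same id. This will not completely remove the skew, but you will still be able to split the rows of the same id across partitions: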
df = spark.createDataFrame(data, ["id", "timestamp", "value"])
bucket_size = 10000  # the right size depends on the timestamp distribution
# Add timestamp-based salt column to the dataframe
df = df.withColumn("salt", F.floor(F.col("timestamp") / F.lit(bucket_size)))
# Get partial cumulative sums
window_salted = Window.partitionBy("id", "salt").orderBy("timestamp")
df = df.withColumn("cumulative_sum", F.sum("value").over(window_salted))
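# Total of each (id, salt) bucket; df2 has one row per bucket, so it is small
df2 = df.groupby("id", "salt").agg(F.sum("value").alias("cumulative_sum_last"))
# Offset for each bucket = sum of the totals of all earlier buckets of the same id
# (lag shifts the totals down by one bucket, the running sum then accumulates them)
window_full = Window.partitionBy("id").orderBy("salt")
df2 = df2.withColumn("previous_sum", F.lag("cumulative_sum_last", default=0).over(window_full))
df2 = df2.withColumn("previous_cumulative_sum", F.sum("previous_sum").over(window_full))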
# Join previous partial cumulative sums with original data
df = df.join(df2, ["id", "salt"]) # maybe F.broadcast(df2) if it is small enough
# Increase each cumulative sum value by final value of the previous window
df = df.withColumn('final_cumulative_sum', F.col('cumulative_sum') + F.col('previous_cumulative_sum'))
# expected value (for verification only: this is the skewed unsalted window)
window_unsalted = Window.partitionBy("id").orderBy("timestamp")
df = df.withColumn("Expected", F.sum('value').over(window_unsalted))
# new calculation
df.agg(F.sum('final_cumulative_sum')).show()
# expected value
df.agg(F.sum('Expected')).show()
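To see why the timestamp-based salt reproduces the unsalted window row by row, here is a plain-Python sketch of the same four steps. The function names salted_running_sum and direct_running_sum are hypothetical, chosen for illustration. It assumes timestamps are unique within each id (Spark's default frame for an ordered window is RANGE-based, so tied timestamps would share a value), and it uses integer values so the comparison is exact rather than subject to floating-point summation order.

```python
import random
from collections import defaultdict


def salted_running_sum(rows, bucket_size):
    """rows: iterable of (id, timestamp, value), timestamps unique per id.
    Mirrors the salted Spark steps; returns {(id, timestamp): running sum}."""
    # 1. Order-preserving salt: each (id, salt) covers one contiguous time range
    groups = defaultdict(list)
    for id_, ts, v in rows:
        groups[(id_, ts // bucket_size)].append((ts, v))

    # 2. Total per bucket, like groupby("id", "salt").agg(F.sum("value"))
    totals = {key: sum(v for _, v in grp) for key, grp in groups.items()}

    # 3. Offset per bucket = sum of all earlier salts of the same id,
    #    like lag(...) followed by a running sum over orderBy("salt")
    salts_by_id = defaultdict(list)
    for id_, salt in totals:
        salts_by_id[id_].append(salt)
    offsets = {}
    for id_, salts in salts_by_id.items():
        running = 0
        for salt in sorted(salts):
            offsets[(id_, salt)] = running
            running += totals[(id_, salt)]

    # 4. Running sum inside each bucket, shifted by that bucket's offset
    result = {}
    for (id_, salt), grp in groups.items():
        local = 0
        for ts, v in sorted(grp):
            local += v
            result[(id_, ts)] = local + offsets[(id_, salt)]
    return result


def direct_running_sum(rows):
    """Reference: unsalted running sum per id, ordered by timestamp."""
    by_id = defaultdict(list)
    for id_, ts, v in rows:
        by_id[id_].append((ts, v))
    result = {}
    for id_, grp in by_id.items():
        running = 0
        for ts, v in sorted(grp):
            running += v
            result[(id_, ts)] = running
    return result


# Skewed toy data: one heavy id and one light id, integer values for exact comparison
rng = random.Random(0)
rows = [(7329, ts, rng.choice([1, -1])) for ts in rng.sample(range(10**6), 1000)]
rows += [(42, ts, rng.choice([1, -1])) for ts in rng.sample(range(10**6), 10)]

print(salted_running_sum(rows, bucket_size=10_000) == direct_running_sum(rows))  # True
```

Because floor(timestamp / bucket_size) never decreases as the timestamp grows, every row in an earlier salt of the same id is also earlier in time, which is why a single offset per bucket is enough. In Spark, the steps that touch every row are keyed by (id, salt), while the window over salt only sees one row per bucket. To judge whether bucket_size is small enough, one option is to look at the largest buckets, e.g. df.groupBy("id", "salt").count().orderBy(F.desc("count")).show().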