Tags: python, apache-spark, pyspark, apache-spark-sql

PySpark: CumSum with Salting over Window w/ Skew


How can I use salting to perform a cumulative sum window operation? Although the sample below is tiny, my real id column is heavily skewed, and I effectively need to perform this operation on it:

window_unsalted = Window.partitionBy("id").orderBy("timestamp")  

# expected value
df = df.withColumn("Expected", F.sum('value').over(window_unsalted))

However, I want to try salting because at the scale of my data, I cannot compute it otherwise.

Consider this MWE. How can I replicate the expected value, 20, using salting techniques?

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# create/get a session so the snippet also runs outside a shell that predefines `spark`
spark = SparkSession.builder.getOrCreate()

data = [  
    (7329, 1636617182, 1.0),  
    (7329, 1636142065, 1.0),  
    (7329, 1636142003, 1.0),  
    (7329, 1680400388, 1.0),  
    (7329, 1636142400, 1.0),  
    (7329, 1636397030, 1.0),  
    (7329, 1636142926, 1.0),  
    (7329, 1635970969, 1.0),  
    (7329, 1636122419, 1.0),  
    (7329, 1636142195, 1.0),  
    (7329, 1636142654, 1.0),  
    (7329, 1636142484, 1.0),  
    (7329, 1636119628, 1.0),  
    (7329, 1636404275, 1.0),  
    (7329, 1680827925, 1.0),  
    (7329, 1636413478, 1.0),  
    (7329, 1636143578, 1.0),  
    (7329, 1636413800, 1.0),  
    (7329, 1636124556, 1.0),  
    (7329, 1636143614, 1.0),  
    (7329, 1636617778, -1.0),  
    (7329, 1636142155, -1.0),  
    (7329, 1636142061, -1.0),  
    (7329, 1680400415, -1.0),  
    (7329, 1636142480, -1.0),  
    (7329, 1636400183, -1.0),  
    (7329, 1636143444, -1.0),  
    (7329, 1635977251, -1.0),  
    (7329, 1636122624, -1.0),  
    (7329, 1636142298, -1.0),  
    (7329, 1636142720, -1.0),  
    (7329, 1636142584, -1.0),  
    (7329, 1636122147, -1.0),  
    (7329, 1636413382, -1.0),  
    (7329, 1680827958, -1.0),  
    (7329, 1636413538, -1.0),  
    (7329, 1636143610, -1.0),  
    (7329, 1636414011, -1.0),  
    (7329, 1636141936, -1.0),  
    (7329, 1636146843, -1.0)  
]  
  
df = spark.createDataFrame(data, ["id", "timestamp", "value"])  
  
# Define the number of salt buckets  
num_buckets = 100  
  
# Add a salted_id column to the dataframe  
df = df.withColumn("salted_id", (F.concat(F.col("id"),   
                (F.rand(seed=42)*num_buckets).cast("int")).cast("string")))  
  
# Define a window partitioned by the salted_id, and ordered by timestamp  
window = Window.partitionBy("salted_id").orderBy("timestamp")  
  
# Add a cumulative sum column  
df = df.withColumn("cumulative_sum", F.sum("value").over(window))  
  
# Define a window partitioned by the original id, and ordered by timestamp  
window_unsalted = Window.partitionBy("id").orderBy("timestamp")  
  
# Compute the final cumulative sum by adding up the cumulative sums within each original id  
df = df.withColumn("final_cumulative_sum",   
                   F.sum("cumulative_sum").over(window_unsalted))  

# expected value
df = df.withColumn("Expected", F.sum('value').over(window_unsalted))

# incorrect attempt: does not match the expected value
df.agg(F.sum('final_cumulative_sum')).show()

# expected value
df.agg(F.sum('Expected')).show()

Solution

  • From what I see, the main issue here is that the timestamps must remain ordered for the partial cumulative sums to be correct; e.g., if the sequence is 1, 2, 3, then 2 cannot go into a different partition than 1 and 3.

    My suggestion is to use a salt value derived from the timestamp, so that the ordering is preserved. This will not completely remove the skew, but it still lets you partition within the same id (a row-level check against the unsalted window is sketched after the code):

    df = spark.createDataFrame(data, ["id", "timestamp", "value"])
    
    bucket_size = 10000  # the actual size will depend on timestamp distribution
    
    # Add timestamp-based salt column to the dataframe
    df = df.withColumn("salt", F.floor(F.col("timestamp") / F.lit(bucket_size)))
    
    # Get partial cumulative sums
    window_salted = Window.partitionBy("id", "salt").orderBy("timestamp")
    df = df.withColumn("cumulative_sum", F.sum("value").over(window_salted))
    
    # Get partial cumulative sums from previous windows
    df2 = df.groupby("id", "salt").agg(F.sum("value").alias("cumulative_sum_last"))
    window_full = Window.partitionBy("id").orderBy("salt")
    df2 = df2.withColumn("previous_sum", F.lag("cumulative_sum_last", default=0).over(window_full))
    df2 = df2.withColumn("previous_cumulative_sum", F.sum("previous_sum").over(window_full))
    
    # Join previous partial cumulative sums with original data
    df = df.join(df2, ["id", "salt"])  # maybe F.broadcast(df2) if it is small enough
    
    # Increase each cumulative sum value by final value of the previous window
    df = df.withColumn('final_cumulative_sum', F.col('cumulative_sum') + F.col('previous_cumulative_sum'))
    
    # expected value
    window_unsalted = Window.partitionBy("id").orderBy("timestamp")
    df = df.withColumn("Expected", F.sum('value').over(window_unsalted))
    
    # new calculation
    df.agg(F.sum('final_cumulative_sum')).show()
    
    # expected value
    df.agg(F.sum('Expected')).show()
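
    As a stronger sanity check than comparing the two aggregates, you can count the rows where the salted and unsalted results disagree. This is a minimal sketch, assuming the df built above already carries both the final_cumulative_sum and Expected columns; a count of 0 means the timestamp-based salting reproduces the plain window exactly.

    # hypothetical row-level check (not part of the approach itself)
    mismatches = df.filter(F.col("final_cumulative_sum") != F.col("Expected")).count()
    print(f"rows where salted and unsalted cumulative sums differ: {mismatches}")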