apache-spark, pyspark, apache-spark-sql, window-functions, rolling-computation

PySpark: high-performance rolling/window aggregations on time-series data


Basic Question

I have a dataset with ~10 billion rows. I'm looking for the most performant way to calculate rolling/windowed aggregates/metrics (sum, mean, min, max, stddev) over four different time windows (3 days, 7 days, 14 days, 21 days).

Spark/AWS EMR Specs

spark version: 2.4.4
ec2 instance type: r5.24xlarge
num core ec2 instances: 10
num pyspark partitions: 600

Overview

I read a number of SO posts that addressed either the mechanics of calculating rolling statistics or how to make Window functions faster, but none of them combined the two concepts in a way that solves my problem. Below I show a few options that produce the correct results, but they need to run faster on my real dataset, so I'm looking for faster/better approaches.

My dataset is structured as follows but with ~10 billion rows:

+--------------------------+----+-----+
|date                      |name|value|
+--------------------------+----+-----+
|2020-12-20 17:45:19.536796|1   |5    |
|2020-12-21 17:45:19.53683 |1   |105  |
|2020-12-22 17:45:19.536846|1   |205  |
|2020-12-23 17:45:19.536861|1   |305  |
|2020-12-24 17:45:19.536875|1   |405  |
|2020-12-25 17:45:19.536891|1   |505  |
|2020-12-26 17:45:19.536906|1   |605  |
|2020-12-20 17:45:19.536796|2   |10   |
|2020-12-21 17:45:19.53683 |2   |110  |
|2020-12-22 17:45:19.536846|2   |210  |
|2020-12-23 17:45:19.536861|2   |310  |
|2020-12-24 17:45:19.536875|2   |410  |
|2020-12-25 17:45:19.536891|2   |510  |
|2020-12-26 17:45:19.536906|2   |610  |
|2020-12-20 17:45:19.536796|3   |15   |
|2020-12-21 17:45:19.53683 |3   |115  |
|2020-12-22 17:45:19.536846|3   |215  |
+--------------------------+----+-----+

I need my dataset to look like the table below. Note: statistics for a 7-day window are shown, but I need the other three windows as well.

+--------------------------+----+-----+----+-----+---+---+------------------+
|date                      |name|value|sum |mean |min|max|stddev            |
+--------------------------+----+-----+----+-----+---+---+------------------+
|2020-12-20 17:45:19.536796|1   |5    |5   |5.0  |5  |5  |NaN               |
|2020-12-21 17:45:19.53683 |1   |105  |110 |55.0 |5  |105|70.71067811865476 |
|2020-12-22 17:45:19.536846|1   |205  |315 |105.0|5  |205|100.0             |
|2020-12-23 17:45:19.536861|1   |305  |620 |155.0|5  |305|129.09944487358058|
|2020-12-24 17:45:19.536875|1   |405  |1025|205.0|5  |405|158.11388300841898|
|2020-12-25 17:45:19.536891|1   |505  |1530|255.0|5  |505|187.08286933869707|
|2020-12-26 17:45:19.536906|1   |605  |2135|305.0|5  |605|216.02468994692867|
|2020-12-20 17:45:19.536796|2   |10   |10  |10.0 |10 |10 |NaN               |
|2020-12-21 17:45:19.53683 |2   |110  |120 |60.0 |10 |110|70.71067811865476 |
|2020-12-22 17:45:19.536846|2   |210  |330 |110.0|10 |210|100.0             |
|2020-12-23 17:45:19.536861|2   |310  |640 |160.0|10 |310|129.09944487358058|
|2020-12-24 17:45:19.536875|2   |410  |1050|210.0|10 |410|158.11388300841898|
|2020-12-25 17:45:19.536891|2   |510  |1560|260.0|10 |510|187.08286933869707|
|2020-12-26 17:45:19.536906|2   |610  |2170|310.0|10 |610|216.02468994692867|
|2020-12-20 17:45:19.536796|3   |15   |15  |15.0 |15 |15 |NaN               |
|2020-12-21 17:45:19.53683 |3   |115  |130 |65.0 |15 |115|70.71067811865476 |
|2020-12-22 17:45:19.536846|3   |215  |345 |115.0|15 |215|100.0             |
+--------------------------+----+-----+----+-----+---+---+------------------+

Details

For ease of reading, I'll just do one window in these examples. Things I have tried:

  1. Basic Window().over() syntax
  2. Converting windowed values into an array column and using higher order functions
  3. Spark SQL

Setup

import datetime

from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import FloatType
import pandas as pd
import numpy as np

spark = SparkSession.builder.appName('example').getOrCreate()

# create spark dataframe
n = 7
names = [1, 2, 3]
date_list = [datetime.datetime.today() - datetime.timedelta(days=(n-x)) for x in range(n)]
values = [x*100 for x in range(n)]

rows = []
for name in names:
    for d, v in zip(date_list, values):
        rows.append(
            {
                "name": name,
                "date": d,
                "value": v+(5*name)
            }
        )
df = spark.createDataFrame(data=rows)

# set up the 7-day window; rangeBetween bounds are in seconds because date is cast to a long
window_days = 7
window = (
    Window
    .partitionBy(F.col("name"))
    .orderBy(F.col("date").cast("timestamp").cast("long"))
    .rangeBetween(-window_days * 60 * 60 * 24 + 1, Window.currentRow)
)
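
Since I ultimately need four lookback windows, here is a small sketch of how the same seconds-based rangeBetween pattern could be built once per lookback (3, 7, 14, 21 days); the names DAY_SECONDS and window_specs are just illustrative:

# Sketch: one Window spec per lookback window (3, 7, 14, 21 days);
# the bounds are in seconds because `date` is cast to a long.
DAY_SECONDS = 60 * 60 * 24
window_specs = {
    d: (
        Window
        .partitionBy(F.col("name"))
        .orderBy(F.col("date").cast("timestamp").cast("long"))
        .rangeBetween(-d * DAY_SECONDS + 1, Window.currentRow)
    )
    for d in [3, 7, 14, 21]
}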

1. Basic

This creates multiple window specs, as shown here, and is therefore executed serially; it runs very slowly on a large dataset.

status_quo = (df
    .withColumn("sum",F.sum(F.col("value")).over(window))
    .withColumn("mean",F.avg(F.col("value")).over(window))
    .withColumn("min",F.min(F.col("value")).over(window))
    .withColumn("max",F.max(F.col("value")).over(window))
    .withColumn("stddev",F.stddev(F.col("value")).over(window))
)
status_quo.show()
status_quo.explain()
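
A possible variant (not benchmarked here): computing all five aggregates in a single select() over the same window spec keeps the logical plan flatter than chained withColumn calls. The name status_quo_select is just illustrative:

# Sketch: same aggregates in one select() over the same window spec,
# avoiding repeated withColumn calls on the logical plan.
status_quo_select = df.select(
    "*",
    F.sum("value").over(window).alias("sum"),
    F.avg("value").over(window).alias("mean"),
    F.min("value").over(window).alias("min"),
    F.max("value").over(window).alias("max"),
    F.stddev("value").over(window).alias("stddev"),
)
status_quo_select.explain()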

2. Array Column --> Higher Order Functions

Per this answer, this approach seems to create fewer window specs, but the aggregate() function syntax makes no sense to me, I don't know how to write stddev using higher-order functions, and the performance doesn't seem much better in small tests.

@F.udf(returnType=FloatType())
def array_stddev(row_value):
    """
    Temporary UDF since I don't know how to write a higher-order standard deviation.
    ddof=1 gives the sample standard deviation, matching F.stddev.
    """
    return np.std(row_value, dtype=float, ddof=1).tolist()

# 1. collect window into array column
# 2. use higher order (array) functions to calculate aggregations over array (window values)
# Question: how to write standard deviation in aggregate()
hof_example = (
    df
    .withColumn("value_array", F.collect_list(F.col("value")).over(window))
    .withColumn("sum_example", F.expr('AGGREGATE(value_array, DOUBLE(0), (acc, x) -> acc + x)'))
    .withColumn("mean_example", F.expr('AGGREGATE(value_array, DOUBLE(0), (acc, x) -> acc + x, acc -> acc / size(value_array))'))
    .withColumn("max_example", F.array_max(F.col("value_array")))
    .withColumn("min_example", F.array_min(F.col("value_array")))
    .withColumn("std_example", array_stddev(F.col("value_array")))
)
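
For reference, one way the standard deviation could be written with aggregate() instead of the UDF is to reuse the mean_example column in a sum of squared deviations, dividing by (size - 1) to match the sample standard deviation of F.stddev (the same idea appears in the solution below). A sketch, assuming the columns created above; hof_std_sketch is an illustrative name:

# Sketch: sample standard deviation via aggregate(), reusing mean_example
# from the previous step; divides by (size - 1) to match F.stddev.
hof_std_sketch = hof_example.withColumn(
    "std_hof_sketch",
    F.expr('SQRT(AGGREGATE(value_array, DOUBLE(0), (acc, x) -> acc + (x - mean_example)*(x - mean_example)) / (SIZE(value_array) - 1))')
)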

3. Spark SQL

This appears to be the fastest in small tests, though I haven't tested it on the full dataset. The only (minor) issue is that the rest of my codebase uses the DataFrame API.

df.createOrReplaceTempView("df")
sql_example = spark.sql(
    """
    SELECT 
        *
        , sum(value)
        OVER (
            PARTITION BY name
            ORDER BY CAST(date AS timestamp) 
            RANGE BETWEEN INTERVAL 7 DAYS PRECEDING AND CURRENT ROW
        ) AS sum
        , mean(value)
        OVER (
            PARTITION BY name
            ORDER BY CAST(date AS timestamp) 
            RANGE BETWEEN INTERVAL 7 DAYS PRECEDING AND CURRENT ROW
        ) AS mean
        , min(value)
        OVER (
            PARTITION BY name
            ORDER BY CAST(date AS timestamp) 
            RANGE BETWEEN INTERVAL 7 DAYS PRECEDING AND CURRENT ROW
        ) AS min
        , max(value)
        OVER (
            PARTITION BY name
            ORDER BY CAST(date AS timestamp) 
            RANGE BETWEEN INTERVAL 7 DAYS PRECEDING AND CURRENT ROW
        ) AS max
        , stddev(value)
        OVER (
            PARTITION BY name
            ORDER BY CAST(date AS timestamp) 
            RANGE BETWEEN INTERVAL 7 DAYS PRECEDING AND CURRENT ROW
        ) AS stddev
    FROM df"""
)
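
If staying in the DataFrame API matters, the same window expressions can (as far as I know) be passed through F.expr(), since expr() accepts OVER clauses. A sketch with illustrative names (over_7d, expr_example):

# Sketch: the same SQL window aggregates via F.expr() so the rest of the
# codebase can stay in DataFrame style (assumes expr() parses OVER clauses).
over_7d = (
    "OVER (PARTITION BY name ORDER BY CAST(date AS timestamp) "
    "RANGE BETWEEN INTERVAL 7 DAYS PRECEDING AND CURRENT ROW)"
)
expr_example = df.select(
    "*",
    F.expr("sum(value) " + over_7d).alias("sum"),
    F.expr("mean(value) " + over_7d).alias("mean"),
    F.expr("min(value) " + over_7d).alias("min"),
    F.expr("max(value) " + over_7d).alias("max"),
    F.expr("stddev(value) " + over_7d).alias("stddev"),
)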

Solution

  • NOTE: I'm going to mark this as the accepted answer for the time being. If someone finds a faster/better approach, please let me know and I'll switch it!

    EDIT/Clarification: The calculations shown here assume the input dataframe has been pre-processed to the day level, with day-level rolling calculations.

    After I posted the question I tested several different options on my real dataset (and got some input from coworkers), and I believe the fastest way to do this for large datasets uses pyspark.sql.functions.window() with groupBy().agg() instead of pyspark.sql.window.Window().

    A similar answer can be found here

    The steps to make this work are:

    1. Sort the dataframe by name and date (as in the example dataframe)
    2. .persist() the dataframe
    3. Compute a grouped dataframe using F.window() and join it back to df, once for each window required.

    The best/easiest way to see this in action is the SQL diagram in the SQL tab of the Spark UI. When Window() is used, the SQL execution is entirely sequential; when F.window() is used, the diagram shows parallelization! NOTE: on small datasets Window() still seems faster.

    In my tests with real data on 7-day windows, Window() was 3-5x slower than F.window(). The only downside is that F.window() is a bit less convenient to use. I've shown some code and screenshots below for reference.

    Fastest Solution Found (F.window() with groupBy().agg())

    # this turned out to be super important for tricking spark into parallelizing things
    df = df.orderBy("name", "date")
    df.persist()
    
    fwindow7 = F.window(
        F.col("date"),
        windowDuration="7 days",
        slideDuration="1 days",
    ).alias("window")
    
    gb7 = (
        df
        .groupBy(F.col("name"), fwindow7)
        .agg(
            F.sum(F.col("value")).alias("sum7"),
            F.avg(F.col("value")).alias("mean7"),
            F.min(F.col("value")).alias("min7"),
            F.max(F.col("value")).alias("max7"),
            F.stddev(F.col("value")).alias("stddev7"),
            F.count(F.col("value")).alias("cnt7")
        )
        .withColumn("date", F.date_sub(F.col("window.end").cast("date"), 1))
        .drop("window")
    )
    window_function_example = df.join(gb7, ["name", "date"], how="left")
    
    
    fwindow14 = F.window(
        F.col("date"),
        windowDuration="14 days",
        slideDuration="1 days",
    ).alias("window")
    
    gb14 = (
        df
        .groupBy(F.col("name"), fwindow14)
        .agg(
            F.sum(F.col("value")).alias("sum14"),
            F.avg(F.col("value")).alias("mean14"),
            F.min(F.col("value")).alias("min14"),
            F.max(F.col("value")).alias("max14"),
            F.stddev(F.col("value")).alias("stddev14"),
            F.count(F.col("value")).alias("cnt14")
        )
        .withColumn("date", F.date_sub(F.col("window.end").cast("date"), 1))
        .drop("window")
    )
    window_function_example = window_function_example.join(gb14, ["name", "date"], how="left")
    window_function_example.orderBy("name", "date").show(truncate=True)
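
    Since the 7-day and 14-day blocks above are identical except for the duration and column suffixes, the same pattern can be looped over all four lookbacks (3, 7, 14, 21 days). A minimal sketch, assuming (per the clarification above) that date is already at day level; result and gb are just illustrative names:

    # Sketch: loop the groupBy/agg/join pattern over all four lookbacks
    # (assumes `date` is already truncated to the day level, per the note above).
    result = df
    for days in [3, 7, 14, 21]:
        fwindow = F.window(
            F.col("date"),
            windowDuration="{} days".format(days),
            slideDuration="1 days",
        ).alias("window")
        gb = (
            df
            .groupBy(F.col("name"), fwindow)
            .agg(
                F.sum("value").alias("sum{}".format(days)),
                F.avg("value").alias("mean{}".format(days)),
                F.min("value").alias("min{}".format(days)),
                F.max("value").alias("max{}".format(days)),
                F.stddev("value").alias("stddev{}".format(days)),
                F.count("value").alias("cnt{}".format(days)),
            )
            .withColumn("date", F.date_sub(F.col("window.end").cast("date"), 1))
            .drop("window")
        )
        result = result.join(gb, ["name", "date"], how="left")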
    

    SQL Diagram

    (screenshot: Group By plan)

    Option 2 from Original Question (Higher Order Functions applied to Window())

    window7 = (
        Window
        .partitionBy(F.col("name"))
        .orderBy(F.col("date").cast("timestamp").cast("long"))
        .rangeBetween(-7 * 60 * 60 * 24 + 1, Window.currentRow)
    )
    window14 = (
        Window
        .partitionBy(F.col("name"))
        .orderBy(F.col("date").cast("timestamp").cast("long"))
        .rangeBetween(-14 * 60 * 60 * 24 + 1, Window.currentRow)
    )
    hof_example = (
        df
        .withColumn("value_array", F.collect_list(F.col("value")).over(window7))
        .withColumn("sum7", F.expr('AGGREGATE(value_array, DOUBLE(0), (acc, x) -> acc + x)'))
        .withColumn("mean7", F.expr('AGGREGATE(value_array, DOUBLE(0), (acc, x) -> acc + x, acc -> acc / size(value_array))'))
        .withColumn("max7", F.array_max(F.col("value_array")))
        .withColumn("min7", F.array_min(F.col("value_array")))
        .withColumn("std7", F.expr('AGGREGATE(value_array, DOUBLE(0), (acc, x) -> acc + (x - mean7)*(x - mean7), acc -> sqrt(acc / (size(value_array) - 1)))'))
        .withColumn("count7", F.size(F.col("value_array")))
        .drop("value_array")
    )
    hof_example = (
        hof_example
        .withColumn("value_array", F.collect_list(F.col("value")).over(window14))
        .withColumn("sum14", F.expr('AGGREGATE(value_array, DOUBLE(0), (acc, x) -> acc + x)'))
        .withColumn("mean14", F.expr('AGGREGATE(value_array, DOUBLE(0), (acc, x) -> acc + x, acc -> acc / size(value_array))'))
        .withColumn("max14", F.array_max(F.col("value_array")))
        .withColumn("min14", F.array_min(F.col("value_array")))
        .withColumn("std14", F.expr('AGGREGATE(value_array, DOUBLE(0), (acc, x) -> acc + (x - mean14)*(x - mean14), acc -> sqrt(acc / (size(value_array) - 1)))'))
        .withColumn("count14", F.size(F.col("value_array")))
        .drop("value_array")
    )
    
    hof_example.show(truncate=True)
    

    SQL Diagram Snippet

    (screenshot: Higher Order Functions plan)