Tags: apache-spark, pyspark, data-processing, incremental-search

How to track number of distinct values incrementally from a spark table?


Suppose we have a very large table that we'd like to process statistics for incrementally.

Date        Amount  Customer
2022-12-20  30      Mary
2022-12-21  12      Mary
2022-12-20  12      Bob
2022-12-21  15      Bob
2022-12-22  15      Alice

We'd like to be able to calculate incrementally how much we made per distinct customer for a date range. From 12-20 to 12-22 (inclusive) we'd have 3 distinct customers, but from 12-20 to 12-21 there are only 2 distinct customers.

If we want to run this pipeline once a day and there are many customers, how can we keep a rolling count of distinct customers for an arbitrary date range? Is there a way to do this without storing a huge list of customer names for each day?

We'd like to support a frontend that has a date range filter and can quickly calculate results for that date range. For example:

Start Date  End Date    Average Income Per Customer
2022-12-20  2022-12-21  (30+12+12+15)/2 = 34.5
2022-12-20  2022-12-22  (30+12+12+15+15)/3 = 28

The only approach I can think of is to store a set of customer names for each day and, when viewing the results, take the size of the union of the per-day sets to get the number of distinct customers. This seems inefficient. In this case we'd store the following table, with the Customers column being extremely large.

Date        Total Income  Customers
2022-12-20  42            set(Mary, Bob)
2022-12-21  27            set(Mary, Bob)
2022-12-22  15            set(Alice)
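
To make the set-based idea concrete, here is a minimal plain-Python sketch (no Spark) that unions the per-day sets from the table above and reproduces the two averages from the date-range example. The `daily` dictionary and the function name are illustrative assumptions, not part of any existing pipeline.

```python
# Per-day income totals and customer sets, copied from the table above.
daily = {
    "2022-12-20": {"income": 42, "customers": {"Mary", "Bob"}},
    "2022-12-21": {"income": 27, "customers": {"Mary", "Bob"}},
    "2022-12-22": {"income": 15, "customers": {"Alice"}},
}

def average_income_per_customer(start, end):
    """Union the per-day sets in [start, end] and divide total income by the union size."""
    days = [d for d in daily if start <= d <= end]  # ISO dates compare correctly as strings
    total = sum(daily[d]["income"] for d in days)
    customers = set().union(*(daily[d]["customers"] for d in days))
    return total / len(customers)

print(average_income_per_customer("2022-12-20", "2022-12-21"))  # 34.5
print(average_income_per_customer("2022-12-20", "2022-12-22"))  # 28.0
```

The cost concern is visible here: every query has to touch one full customer set per day in the range.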

Solution

  • For me, the best solution is to precompute results for the existing data; then, for the new data that arrives every day, run the calculation only on that new data and add the result to the previously calculated data. Also partition on the date column, since we filter on dates: this lets Spark use filter pushdown to skip partitions and speeds up your queries.

    There are two parts: one to get the total amount between two dates, and another to get the distinct customers between two dates:

    • For the amount, use a prefix sum: for each day, store the sum of that day and all previous days. To get the total between two dates, subtract the stored values for just those two days, without looping over all the dates in between.

    • For distinct customers, the best approach I can think of is to save the Date and Customer columns to a new file partitioned by date, which helps optimize the queries, and then use the fast approx_count_distinct, which returns an approximate count.
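
As a quick check of the prefix-sum idea, here is a minimal plain-Python sketch on the sample data (daily totals 42, 27, 15). The in-memory lists and the name `range_amount` are assumptions for illustration; they stand in for the stored running-total table.

```python
from itertools import accumulate

# Daily totals from the sample data (hypothetical in-memory stand-in for the stored table).
dates = ["2022-12-20", "2022-12-21", "2022-12-22"]
daily_amount = [42, 27, 15]

running = list(accumulate(daily_amount))      # total up to and including each date
prefix = dict(zip(dates, running))            # 42, 69, 84
prev_prefix = dict(zip(dates, [0] + running[:-1]))  # total strictly before each date: 0, 42, 69

def range_amount(begin, end):
    """Sum over [begin, end] using two lookups instead of a scan over the range."""
    return prefix[end] - prev_prefix[begin]

print(range_amount("2022-12-20", "2022-12-21"))  # 69
print(range_amount("2022-12-21", "2022-12-22"))  # 42
print(range_amount("2022-12-20", "2022-12-22"))  # 84
```

Storing the "strictly before" total next to the running total means the begin date needs no lookup of the previous day's row.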

    Here's some code:

    from pyspark.sql import SparkSession, Window
    from pyspark.sql import functions as F

    spark = SparkSession.builder.master("local[*]").getOrCreate()
    data = [
        ["2022-12-20", 30, "Mary"],
        ["2022-12-21", 12, "Mary"],
        ["2022-12-20", 12, "Bob"],
        ["2022-12-21", 15, "Bob"],
        ["2022-12-22", 15, "Alice"],
    ]
    df = spark.createDataFrame(data).toDF("Date", "Amount", "Customer")

    # Run once: store, per date, the running total (amount_sum) and the total before that date (prev_amount_sum)
    def init_amount_data(df):
        w = Window.orderBy(F.col("Date"))
        amount_sum_df = df.groupby("Date").agg(F.sum("Amount").alias("Amount")) \
            .withColumn("amount_sum", F.sum(F.col("Amount")).over(w)) \
            .withColumn("prev_amount_sum", F.lag("amount_sum", 1, 0).over(w)) \
            .select("Date", "amount_sum", "prev_amount_sum")
        amount_sum_df.write.mode("overwrite").partitionBy("Date").parquet("./path/amount_data_df")
        amount_sum_df.show(truncate=False)

    # Keep only the customer data to avoid reading unnecessary columns when querying;
    # partitioning by Date makes queries faster thanks to Spark's filter pushdown
    def init_customers_data(df):
        df.select("Date", "Customer").write.mode("overwrite").partitionBy("Date").parquet("./path/customers_data_df")

    # Each day (for example at midnight), update the amount table with yesterday's data only:
    # take the last amount_sum and add yesterday's amount to it
    def update_amount_data(last_partition):
        amount_data_df = spark.read.parquet("./path/amount_data_df")
        # Latest date already stored; listing the partition directories on the file system would also work
        max_date = amount_data_df.agg(F.max("Date")).first()[0]
        last_amount_sum = amount_data_df.filter(F.col("Date") == max_date).select("amount_sum").first()[0]
        new_partition = last_partition.groupby("Date").agg(F.sum("Amount").alias("amount_sum")) \
            .withColumn("amount_sum", F.col("amount_sum") + F.lit(last_amount_sum)) \
            .withColumn("prev_amount_sum", F.lit(last_amount_sum).cast("long"))  # same type as the stored column
        new_partition.write.mode("append").partitionBy("Date").parquet("./path/amount_data_df")

    def update_customers_data(last_partition):
        last_partition.select("Date", "Customer").write.mode("append").partitionBy("Date").parquet("./path/customers_data_df")

    # Total amount in [beginDate, endDate]; both dates must have a row in the amount table
    def query_amount_date(beginDate, endDate):
        amount_data_df = spark.read.parquet("./path/amount_data_df")
        end_amount = amount_data_df.filter(F.col("Date") == endDate).select("amount_sum").first()[0]
        begin_prev_amount = amount_data_df.filter(F.col("Date") == beginDate).select("prev_amount_sum").first()[0]
        return end_amount - begin_prev_amount

    # Approximate number of distinct customers in [beginDate, endDate]
    def query_customers_date(beginDate, endDate):
        customers_data_df = spark.read.parquet("./path/customers_data_df")
        return customers_data_df.filter(F.col("Date").between(F.lit(beginDate), F.lit(endDate))) \
            .agg(F.approx_count_distinct("Customer").alias("distinct_customers")).first()[0]

    # This should be executed the first time only, on the data that already exists.
    # For this example, 2022-12-22 plays the role of "yesterday", so it is left out of the initial load.
    yesterday_date = "2022-12-22"
    history_df = df.filter(F.col("Date") < yesterday_date)
    init_amount_data(history_df)
    init_customers_data(history_df)
    # This should be executed every day at midnight, with the data of the last day only
    last_day_partition = df.filter(F.col("Date") == yesterday_date)
    update_amount_data(last_day_partition)
    update_customers_data(last_day_partition)
    # Optimized queries, run on demand for a date range
    beginDate = "2022-12-20"
    endDate = "2022-12-22"
    answer = query_amount_date(beginDate, endDate) / query_customers_date(beginDate, endDate)
    print(answer)
    

    If calculating the distinct customers is not fast enough, there's another approach that uses the same prefix-sum idea: keep one table with the running count of distinct customers and a second table listing every customer seen so far. Each day, for each customer who is not yet in the second table, increment the count in the first table and add that customer to the second table; otherwise do nothing.
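
A minimal plain-Python sketch of this bookkeeping on the sample data follows. The names `seen` and `running_distinct` are assumptions standing in for the two tables described above.

```python
# Customers who bought on each day, from the sample data.
daily_customers = {
    "2022-12-20": ["Mary", "Bob"],
    "2022-12-21": ["Mary", "Bob"],
    "2022-12-22": ["Alice"],
}

seen = set()           # stand-in for the table of customers seen so far
running_distinct = {}  # stand-in for the table of cumulative distinct counts per date

for day in sorted(daily_customers):
    new_customers = set(daily_customers[day]) - seen  # only customers never seen before
    seen |= new_customers
    running_distinct[day] = len(seen)

print(running_distinct)
# {'2022-12-20': 2, '2022-12-21': 2, '2022-12-22': 3}

# Differencing the running count over the single-day range 2022-12-21..2022-12-21
# counts only customers FIRST seen in that range.
first_seen_in_range = running_distinct["2022-12-21"] - running_distinct["2022-12-20"]
print(first_seen_in_range)  # 0, although Mary and Bob both bought on 2022-12-21
```

One design consequence worth keeping in mind: the running count directly answers "distinct customers from the first date up to day X". For a range that starts later, the difference of two running counts gives the number of customers first seen in that range, which can be smaller than the number of distinct customers active in it.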

    Finally, there are some tricks for optimizing groupBy or window functions, such as salting or extended partitioning.
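
To illustrate what salting means, here is a plain-Python sketch of the two-stage aggregation it relies on; it is not Spark code, and `NUM_SALTS` is a hypothetical tuning value. In Spark the same shape would typically be a group-by on the key plus a random salt column, followed by a second group-by on the key alone.

```python
import random
from collections import defaultdict

NUM_SALTS = 4  # hypothetical number of salt buckets per key
rows = [("Mary", 30), ("Mary", 12), ("Bob", 12), ("Bob", 15), ("Alice", 15)]

rng = random.Random(0)  # seeded so the run is repeatable

# Stage 1: aggregate on (key, salt), so one very frequent key is spread over several groups.
partial = defaultdict(int)
for customer, amount in rows:
    partial[(customer, rng.randrange(NUM_SALTS))] += amount

# Stage 2: drop the salt and combine the partial sums per key.
totals = defaultdict(int)
for (customer, _salt), amount in partial.items():
    totals[customer] += amount

print(dict(totals))
# {'Mary': 42, 'Bob': 27, 'Alice': 15}
```

The final totals do not depend on which salt each row received; the salt only changes how the work in the first stage is split.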