dataframe, apache-spark, pyspark, apache-spark-sql

SparkSQL sum each row with the nearest item preceding


I have a DataFrame that looks like the following.

+-------+---------+-----+-------------------+
|product| location|  num|                 ts|
+-------+---------+-----+-------------------+
|      1|        A|    3|2024-01-31 04:28:27|
|      1|        B|   12|2024-01-31 04:28:27|
|      1|        C|   19|2024-01-31 04:28:27|
|      1|        D|    1|2024-01-31 04:28:27|
|      1|        E|    4|2024-01-31 04:28:27|
|      1|        D|    2|2024-02-02 04:30:08|
|      1|        A|    4|2024-02-02 04:30:08|
|      1|        C|   20|2024-02-02 04:30:08|
|      1|        E|    1|2024-02-02 04:30:08|
|      1|        F|   20|2024-02-02 04:30:08|
|      1|        E|    2|2024-02-03 04:32:39|
|      1|        D|    5|2024-02-03 04:32:39|
|      1|        A|    3|2024-02-03 04:32:39|
+-------+---------+-----+-------------------+

For each row I want to sum the product total across every location visible as of that time, taking each location's most recent value. I have tried multiple windowing approaches, but the closest I got only sums the rows that share the same timestamp and misses the locations whose latest value comes from an earlier timestamp.

Here is the expected result; it should be calculated as shown below.

+-------+---------+-----+-------------------+-------+
|product| location|  num|                 ts|  total|
+-------+---------+-----+-------------------+-------+
|      1|        A|    3|2024-01-31 04:28:27|     39| --> A(3) + B(12) + C(19) + D(1) + E(4)
|      1|        B|   12|2024-01-31 04:28:27|     39| --> A(3) + B(12) + C(19) + D(1) + E(4)
|      1|        C|   19|2024-01-31 04:28:27|     39| --> A(3) + B(12) + C(19) + D(1) + E(4)
|      1|        D|    1|2024-01-31 04:28:27|     39| --> A(3) + B(12) + C(19) + D(1) + E(4)
|      1|        E|    4|2024-01-31 04:28:27|     39| --> A(3) + B(12) + C(19) + D(1) + E(4)
|      1|        D|    2|2024-02-02 04:30:08|     59| --> A(4) + B(12) + C(20) + D(2) + E(1) + F(20)
|      1|        A|    4|2024-02-02 04:30:08|     59| --> A(4) + B(12) + C(20) + D(2) + E(1) + F(20)
|      1|        C|   20|2024-02-02 04:30:08|     59| --> A(4) + B(12) + C(20) + D(2) + E(1) + F(20)
|      1|        E|    1|2024-02-02 04:30:08|     59| --> A(4) + B(12) + C(20) + D(2) + E(1) + F(20)
|      1|        F|   20|2024-02-02 04:30:08|     59| --> A(4) + B(12) + C(20) + D(2) + E(1) + F(20)
|      1|        E|    2|2024-02-03 04:32:39|     62| --> A(3) + B(12) + C(20) + D(5) + E(2) + F(20)
|      1|        D|    5|2024-02-03 04:32:39|     62| --> A(3) + B(12) + C(20) + D(5) + E(2) + F(20)
|      1|        A|    3|2024-02-03 04:32:39|     62| --> A(3) + B(12) + C(20) + D(5) + E(2) + F(20)
+-------+---------+-----+-------------------+-------+

Is there a way to write Spark SQL or PySpark to achieve this?

Appreciated!!


Solution

  • Check out this solution:

    import pyspark.sql.functions as f
    from pyspark.sql.types import *
    from pyspark.sql import SparkSession
    from pyspark.sql.window import Window
    from datetime import datetime
    
    spark = SparkSession.builder.appName("pyspark_playground").getOrCreate()
    
    df = spark.createDataFrame([
        (1, 'A', 3, datetime(2024, 1, 31, 4, 28, 27)),
        (1, 'B', 12, datetime(2024, 1, 31, 4, 28, 27)),
        (1, 'C', 19, datetime(2024, 1, 31, 4, 28, 27)),
        (1, 'D', 1, datetime(2024, 1, 31, 4, 28, 27)),
        (1, 'E', 4, datetime(2024, 1, 31, 4, 28, 27)),
        (1, 'D', 2, datetime(2024, 2, 2, 4, 30, 8)),
        (1, 'A', 4, datetime(2024, 2, 2, 4, 30, 8)),
        (1, 'C', 20, datetime(2024, 2, 2, 4, 30, 8)),
        (1, 'E', 1, datetime(2024, 2, 2, 4, 30, 8)),
        (1, 'F', 20, datetime(2024, 2, 2, 4, 30, 8)),
        (1, 'E', 2, datetime(2024, 2, 3, 4, 32, 39)),
        (1, 'D', 5, datetime(2024, 2, 3, 4, 32, 39)),
        (1, 'A', 3, datetime(2024, 2, 3, 4, 32, 39)),
    ], ['product', 'location', 'num', 'ts'])
    
    # collect, for every row, the running history of (location, num, ts) structs per product,
    # then keep only the longest array per (product, ts): the complete snapshot as of that timestamp
    df_with_content_array = (
        df
        .withColumn('content_struct', f.struct('location', 'num', 'ts'))
        .withColumn('content_array', f.collect_list('content_struct').over(Window.partitionBy('product').orderBy('ts').rowsBetween(Window.unboundedPreceding, Window.currentRow)))
        .withColumn('content_array_size', f.size('content_array'))
        .withColumn('largest_content_array', f.max('content_array_size').over(Window.partitionBy('product', 'ts')))
        .where(f.col('content_array_size') == f.col('largest_content_array'))
    )
    
    df = (
        df.alias('original')
        .join(df_with_content_array.alias('aggregated'), ['product', 'ts'], 'inner')
        .select('original.*', 'aggregated.content_array')
        # every location observed so far
        .withColumn('distinct_locations', f.array_distinct('content_array.location'))
        # for each location, build [time distance to the current ts, num] pairs from its historical rows
        .withColumn('content_distance', f.expr('transform(distinct_locations, x -> transform(filter(content_array, y -> y.location = x), z -> array(abs(cast(z.ts - ts as int)), z.num)))'))
        # per location, keep the num whose timestamp is closest to the current ts
        .withColumn('closest_num', f.expr('transform(transform(content_distance, x -> filter(x, y -> y[0] = array_min(transform(x, z -> z[0])))[0]), t -> t[1])'))
        # sum those per-location values into the requested total
        .withColumn('sum_closest_num', f.expr('aggregate(closest_num, cast(0 as double), (acc, x) -> acc + x)'))
        .select('product', 'location', 'num', 'ts', 'closest_num', 'sum_closest_num')
    )
    
    df.show(truncate=False)
    

    and the output is:

    +-------+--------+---+-------------------+---------------------+---------------+
    |product|location|num|ts                 |closest_num          |sum_closest_num|
    +-------+--------+---+-------------------+---------------------+---------------+
    |1      |A       |3  |2024-01-31 04:28:27|[3, 12, 19, 1, 4]    |39.0           |
    |1      |B       |12 |2024-01-31 04:28:27|[3, 12, 19, 1, 4]    |39.0           |
    |1      |C       |19 |2024-01-31 04:28:27|[3, 12, 19, 1, 4]    |39.0           |
    |1      |D       |1  |2024-01-31 04:28:27|[3, 12, 19, 1, 4]    |39.0           |
    |1      |E       |4  |2024-01-31 04:28:27|[3, 12, 19, 1, 4]    |39.0           |
    |1      |D       |2  |2024-02-02 04:30:08|[4, 12, 20, 2, 1, 20]|59.0           |
    |1      |A       |4  |2024-02-02 04:30:08|[4, 12, 20, 2, 1, 20]|59.0           |
    |1      |C       |20 |2024-02-02 04:30:08|[4, 12, 20, 2, 1, 20]|59.0           |
    |1      |E       |1  |2024-02-02 04:30:08|[4, 12, 20, 2, 1, 20]|59.0           |
    |1      |F       |20 |2024-02-02 04:30:08|[4, 12, 20, 2, 1, 20]|59.0           |
    |1      |E       |2  |2024-02-03 04:32:39|[3, 12, 20, 5, 2, 20]|62.0           |
    |1      |D       |5  |2024-02-03 04:32:39|[3, 12, 20, 5, 2, 20]|62.0           |
    |1      |A       |3  |2024-02-03 04:32:39|[3, 12, 20, 5, 2, 20]|62.0           |
    +-------+--------+---+-------------------+---------------------+---------------+
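
    Note that the question's expected output has an integer `total` column, while this approach produces `sum_closest_num` as a double. If you want the exact shape from the question, a small optional tweak (the `result` name is just illustrative) is:

    result = df.withColumn('total', f.col('sum_closest_num').cast('long')).drop('closest_num', 'sum_closest_num')
    result.show()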
    

    My colleague and I found a different approach that we think is going to be more efficient:

    import pyspark.sql.functions as f
    from pyspark.sql.types import *
    from pyspark.sql import SparkSession
    from pyspark.sql.window import Window
    from datetime import datetime
    
    spark = SparkSession.builder.appName("pyspark_playground").getOrCreate()
    
    df = spark.createDataFrame([
        (1, 'A', 3, datetime(2024, 1, 31, 4, 28, 27)),
        (1, 'B', 12, datetime(2024, 1, 31, 4, 28, 27)),
        (1, 'C', 19, datetime(2024, 1, 31, 4, 28, 27)),
        (1, 'D', 1, datetime(2024, 1, 31, 4, 28, 27)),
        (1, 'E', 4, datetime(2024, 1, 31, 4, 28, 27)),
        (1, 'D', 2, datetime(2024, 2, 2, 4, 30, 8)),
        (1, 'A', 4, datetime(2024, 2, 2, 4, 30, 8)),
        (1, 'C', 20, datetime(2024, 2, 2, 4, 30, 8)),
        (1, 'E', 1, datetime(2024, 2, 2, 4, 30, 8)),
        (1, 'F', 20, datetime(2024, 2, 2, 4, 30, 8)),
        (1, 'E', 2, datetime(2024, 2, 3, 4, 32, 39)),
        (1, 'D', 5, datetime(2024, 2, 3, 4, 32, 39)),
        (1, 'A', 3, datetime(2024, 2, 3, 4, 32, 39)),
    ], ['product', 'location', 'num', 'ts'])
    
    
    # pivot to one row per (product, ts) with one column per location
    df_pivoted = (
        df
        .groupBy('product', 'ts')
        .pivot('location')
        .agg(f.first('num'))
    )
    locations = set(df_pivoted.columns) - {'product', 'ts'}
    # forward-fill each location column with its last non-null value up to the current ts
    for location in locations:
        df_pivoted = df_pivoted.withColumn(location, f.last(location, True).over(Window.partitionBy('product').orderBy(f.col('ts')).rowsBetween(Window.unboundedPreceding, Window.currentRow)))
    # total = sum of the latest known value per location, treating locations not yet seen as 0
    df_pivoted = (
        df_pivoted
        .withColumn('total', f.expr('+'.join([f'coalesce({location}, 0)' for location in locations])))
    )
    # attach the per-timestamp total back onto the original rows
    output = (
        df
        .join(df_pivoted, ['product', 'ts'], 'inner')
        .select(df['*'], df_pivoted.total.alias('total'))
    )
    
    output.show()
    

    and the output is:

    +-------+--------+---+-------------------+-----+                                
    |product|location|num|                 ts|total|
    +-------+--------+---+-------------------+-----+
    |      1|       A|  3|2024-01-31 04:28:27|   39|
    |      1|       B| 12|2024-01-31 04:28:27|   39|
    |      1|       C| 19|2024-01-31 04:28:27|   39|
    |      1|       D|  1|2024-01-31 04:28:27|   39|
    |      1|       E|  4|2024-01-31 04:28:27|   39|
    |      1|       D|  2|2024-02-02 04:30:08|   59|
    |      1|       A|  4|2024-02-02 04:30:08|   59|
    |      1|       C| 20|2024-02-02 04:30:08|   59|
    |      1|       E|  1|2024-02-02 04:30:08|   59|
    |      1|       F| 20|2024-02-02 04:30:08|   59|
    |      1|       E|  2|2024-02-03 04:32:39|   62|
    |      1|       D|  5|2024-02-03 04:32:39|   62|
    |      1|       A|  3|2024-02-03 04:32:39|   62|
    +-------+--------+---+-------------------+-----+
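
    Since the question also asks about Spark SQL: below is a minimal sketch of the same "latest value per location, summed per snapshot" idea in plain SQL. It assumes the original DataFrame is registered as a temp view (here called `events`, an arbitrary name) and Spark 3.0+ for `max_by`:

    df.createOrReplaceTempView('events')
    
    spark.sql("""
        WITH snapshots AS (            -- every distinct reporting timestamp per product
            SELECT DISTINCT product, ts FROM events
        ),
        latest_per_location AS (       -- each location's most recent num at or before the snapshot
            SELECT s.product, s.ts, e.location, max_by(e.num, e.ts) AS latest_num
            FROM snapshots s
            JOIN events e
              ON e.product = s.product AND e.ts <= s.ts
            GROUP BY s.product, s.ts, e.location
        ),
        totals AS (                    -- sum those latest values per snapshot
            SELECT product, ts, SUM(latest_num) AS total
            FROM latest_per_location
            GROUP BY product, ts
        )
        SELECT e.*, t.total
        FROM events e
        JOIN totals t ON e.product = t.product AND e.ts = t.ts
    """).show()

    Keep in mind that the non-equi join fans every row out across all later snapshots, so on long histories the pivot-and-forward-fill approach above will likely scale better.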