I have a DataFrame which looks like the following.
+-------+---------+-----+-------------------+
|product| location| num| ts|
+-------+---------+-----+-------------------+
| 1| A| 3|2024-01-31 04:28:27|
| 1| B| 12|2024-01-31 04:28:27|
| 1| C| 19|2024-01-31 04:28:27|
| 1| D| 1|2024-01-31 04:28:27|
| 1| E| 4|2024-01-31 04:28:27|
| 1| D| 2|2024-02-02 04:30:08|
| 1| A| 4|2024-02-02 04:30:08|
| 1| C| 20|2024-02-02 04:30:08|
| 1| E| 1|2024-02-02 04:30:08|
| 1| F| 20|2024-02-02 04:30:08|
| 1| E| 2|2024-02-03 04:32:39|
| 1| D| 5|2024-02-03 04:32:39|
| 1| A| 3|2024-02-03 04:32:39|
+-------+---------+-----+-------------------+
For each timestamp, I want to sum the product total across all locations that are visible so far, taking each location's value from its nearest (most recent) timestamp. I have tried multiple ways of windowing, but the closest I got only sums the rows that share the same timestamp, missing the locations whose latest value comes from an earlier timestamp.
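For example, an attempt along these lines (just a rough sketch of the kind of windowing I tried) only counts rows with the exact same timestamp:
import pyspark.sql.functions as f
from pyspark.sql.window import Window

# sums num only over rows with the same product and the exact same ts,
# so locations whose latest value is from an earlier timestamp are missed
w = Window.partitionBy('product', 'ts')
df.withColumn('total', f.sum('num').over(w)).show()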
Here is the expected result; it should be calculated as shown below.
+-------+---------+-----+-------------------+-------+
|product| location| num| ts| total|
+-------+---------+-----+-------------------+-------+
| 1| A| 3|2024-01-31 04:28:27| 39| --> A(3) + B(12) + C(19) + D(1) + E(4)
| 1| B| 12|2024-01-31 04:28:27| 39| --> A(3) + B(12) + C(19) + D(1) + E(4)
| 1| C| 19|2024-01-31 04:28:27| 39| --> A(3) + B(12) + C(19) + D(1) + E(4)
| 1| D| 1|2024-01-31 04:28:27| 39| --> A(3) + B(12) + C(19) + D(1) + E(4)
| 1| E| 4|2024-01-31 04:28:27| 39| --> A(3) + B(12) + C(19) + D(1) + E(4)
| 1| D| 2|2024-02-02 04:30:08| 59| --> A(4) + B(12) + C(20) + D(2) + E(1) + F(20)
| 1| A| 4|2024-02-02 04:30:08| 59| --> A(4) + B(12) + C(20) + D(2) + E(1) + F(20)
| 1| C| 20|2024-02-02 04:30:08| 59| --> A(4) + B(12) + C(20) + D(2) + E(1) + F(20)
| 1| E| 1|2024-02-02 04:30:08| 59| --> A(4) + B(12) + C(20) + D(2) + E(1) + F(20)
| 1| F| 20|2024-02-02 04:30:08| 59| --> A(4) + B(12) + C(20) + D(2) + E(1) + F(20)
| 1| E| 2|2024-02-03 04:32:39| 62| --> A(3) + B(12) + C(20) + D(5) + E(2) + F(20)
| 1| D| 5|2024-02-03 04:32:39| 62| --> A(3) + B(12) + C(20) + D(5) + E(2) + F(20)
| 1| A| 3|2024-02-03 04:32:39| 62| --> A(3) + B(12) + C(20) + D(5) + E(2) + F(20)
+-------+---------+-----+-------------------+-------+
Is there a way to write Spark SQL or PySpark to achieve this?
Appreciated!!
Check out this solution:
import pyspark.sql.functions as f
from pyspark.sql.types import *
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from datetime import datetime
spark = SparkSession.builder.appName("pyspark_playground").getOrCreate()
df = spark.createDataFrame([
    (1, 'A', 3, datetime(2024, 1, 31, 4, 28, 27)),
    (1, 'B', 12, datetime(2024, 1, 31, 4, 28, 27)),
    (1, 'C', 19, datetime(2024, 1, 31, 4, 28, 27)),
    (1, 'D', 1, datetime(2024, 1, 31, 4, 28, 27)),
    (1, 'E', 4, datetime(2024, 1, 31, 4, 28, 27)),
    (1, 'D', 2, datetime(2024, 2, 2, 4, 30, 8)),
    (1, 'A', 4, datetime(2024, 2, 2, 4, 30, 8)),
    (1, 'C', 20, datetime(2024, 2, 2, 4, 30, 8)),
    (1, 'E', 1, datetime(2024, 2, 2, 4, 30, 8)),
    (1, 'F', 20, datetime(2024, 2, 2, 4, 30, 8)),
    (1, 'E', 2, datetime(2024, 2, 3, 4, 32, 39)),
    (1, 'D', 5, datetime(2024, 2, 3, 4, 32, 39)),
    (1, 'A', 3, datetime(2024, 2, 3, 4, 32, 39)),
], ['product', 'location', 'num', 'ts'])
df_with_content_array = (
    df
    # pack each row into a struct so whole rows can be collected into an array
    .withColumn('content_struct', f.struct('location', 'num', 'ts'))
    # running array of every row seen so far per product, ordered by ts
    .withColumn('content_array', f.collect_list('content_struct').over(
        Window.partitionBy('product').orderBy('ts')
        .rowsBetween(Window.unboundedPreceding, Window.currentRow)))
    .withColumn('content_array_size', f.size('content_array'))
    # per (product, ts), keep only the row whose array already contains every row up to that ts
    .withColumn('largest_content_array', f.max('content_array_size').over(Window.partitionBy('product', 'ts')))
    .where(f.col('content_array_size') == f.col('largest_content_array'))
)
df = (
    df.alias('original')
    .join(df_with_content_array.alias('aggregated'), ['product', 'ts'], 'inner')
    .select('original.*', 'aggregated.content_array')
    # all locations that are visible up to and including this timestamp
    .withColumn('distinct_locations', f.array_distinct('content_array.location'))
    # for each location, pair every historical num with its time distance from the current ts
    .withColumn('content_distance', f.expr('transform(distinct_locations, x -> transform(filter(content_array, y -> y.location = x), z -> array(abs(cast(z.ts - ts as int)), z.num)))'))
    # pick, per location, the num with the smallest time distance, i.e. its most recent value
    .withColumn('closest_num', f.expr('transform(transform(content_distance, x -> filter(x, y -> y[0] = array_min(transform(x, z -> z[0])))[0]), t -> t[1])'))
    # sum the most recent value of every visible location
    .withColumn('sum_closest_num', f.expr('aggregate(closest_num, cast(0 as double), (acc, x) -> acc + x)'))
    .select('product', 'location', 'num', 'ts', 'closest_num', 'sum_closest_num')
)
df.show(truncate=False)
and the output is:
+-------+--------+---+-------------------+---------------------+---------------+
|product|location|num|ts |closest_num |sum_closest_num|
+-------+--------+---+-------------------+---------------------+---------------+
|1 |A |3 |2024-01-31 04:28:27|[3, 12, 19, 1, 4] |39.0 |
|1 |B |12 |2024-01-31 04:28:27|[3, 12, 19, 1, 4] |39.0 |
|1 |C |19 |2024-01-31 04:28:27|[3, 12, 19, 1, 4] |39.0 |
|1 |D |1 |2024-01-31 04:28:27|[3, 12, 19, 1, 4] |39.0 |
|1 |E |4 |2024-01-31 04:28:27|[3, 12, 19, 1, 4] |39.0 |
|1 |D |2 |2024-02-02 04:30:08|[4, 12, 20, 2, 1, 20]|59.0 |
|1 |A |4 |2024-02-02 04:30:08|[4, 12, 20, 2, 1, 20]|59.0 |
|1 |C |20 |2024-02-02 04:30:08|[4, 12, 20, 2, 1, 20]|59.0 |
|1 |E |1 |2024-02-02 04:30:08|[4, 12, 20, 2, 1, 20]|59.0 |
|1 |F |20 |2024-02-02 04:30:08|[4, 12, 20, 2, 1, 20]|59.0 |
|1 |E |2 |2024-02-03 04:32:39|[3, 12, 20, 5, 2, 20]|62.0 |
|1 |D |5 |2024-02-03 04:32:39|[3, 12, 20, 5, 2, 20]|62.0 |
|1 |A |3 |2024-02-03 04:32:39|[3, 12, 20, 5, 2, 20]|62.0 |
+-------+--------+---+-------------------+---------------------+---------------+
My colleague and I found a different approach that we think is going to be more efficient:
import pyspark.sql.functions as f
from pyspark.sql.types import *
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from datetime import datetime
spark = SparkSession.builder.appName("pyspark_playground").getOrCreate()
df = spark.createDataFrame([
    (1, 'A', 3, datetime(2024, 1, 31, 4, 28, 27)),
    (1, 'B', 12, datetime(2024, 1, 31, 4, 28, 27)),
    (1, 'C', 19, datetime(2024, 1, 31, 4, 28, 27)),
    (1, 'D', 1, datetime(2024, 1, 31, 4, 28, 27)),
    (1, 'E', 4, datetime(2024, 1, 31, 4, 28, 27)),
    (1, 'D', 2, datetime(2024, 2, 2, 4, 30, 8)),
    (1, 'A', 4, datetime(2024, 2, 2, 4, 30, 8)),
    (1, 'C', 20, datetime(2024, 2, 2, 4, 30, 8)),
    (1, 'E', 1, datetime(2024, 2, 2, 4, 30, 8)),
    (1, 'F', 20, datetime(2024, 2, 2, 4, 30, 8)),
    (1, 'E', 2, datetime(2024, 2, 3, 4, 32, 39)),
    (1, 'D', 5, datetime(2024, 2, 3, 4, 32, 39)),
    (1, 'A', 3, datetime(2024, 2, 3, 4, 32, 39)),
], ['product', 'location', 'num', 'ts'])
df_pivoted = (
    df
    # one row per (product, ts), one column per location
    .groupBy('product', 'ts')
    .pivot('location')
    .agg(f.first('num'))
)
locations = set(df_pivoted.columns) - {'product', 'ts'}
# forward-fill each location column with its most recent non-null value per product
for location in locations:
    df_pivoted = df_pivoted.withColumn(location, f.last(location, True).over(
        Window.partitionBy('product').orderBy(f.col('ts'))
        .rowsBetween(Window.unboundedPreceding, Window.currentRow)))
df_pivoted = (
    df_pivoted
    # total per (product, ts): sum of every location's forward-filled value, 0 if never seen
    .withColumn('total', f.expr('+'.join([f'coalesce({location}, 0)' for location in locations])))
)
output = (
    df
    .join(df_pivoted, ['product', 'ts'], 'inner')
    .select(df['*'], df_pivoted.total.alias('total'))
)
output.show()
and the output is:
+-------+--------+---+-------------------+-----+
|product|location|num| ts|total|
+-------+--------+---+-------------------+-----+
| 1| A| 3|2024-01-31 04:28:27| 39|
| 1| B| 12|2024-01-31 04:28:27| 39|
| 1| C| 19|2024-01-31 04:28:27| 39|
| 1| D| 1|2024-01-31 04:28:27| 39|
| 1| E| 4|2024-01-31 04:28:27| 39|
| 1| D| 2|2024-02-02 04:30:08| 59|
| 1| A| 4|2024-02-02 04:30:08| 59|
| 1| C| 20|2024-02-02 04:30:08| 59|
| 1| E| 1|2024-02-02 04:30:08| 59|
| 1| F| 20|2024-02-02 04:30:08| 59|
| 1| E| 2|2024-02-03 04:32:39| 62|
| 1| D| 5|2024-02-03 04:32:39| 62|
| 1| A| 3|2024-02-03 04:32:39| 62|
+-------+--------+---+-------------------+-----+
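Since the question also mentions Spark SQL, roughly the same forward-fill idea can be sketched in plain SQL (this is only a sketch: the events view name and the CTE names are placeholders, not anything from the code above):
df.createOrReplaceTempView('events')

spark.sql("""
    WITH grid AS (
        -- one row per (product, ts, location) for every location the product ever reports
        SELECT t.product, t.ts, l.location
        FROM (SELECT DISTINCT product, ts FROM events) t
        JOIN (SELECT DISTINCT product, location FROM events) l
          ON t.product = l.product
    ),
    filled AS (
        -- forward-fill each location's most recent num as of each timestamp
        SELECT g.product, g.ts,
               last(e.num, true) OVER (
                   PARTITION BY g.product, g.location
                   ORDER BY g.ts
                   ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
               ) AS num
        FROM grid g
        LEFT JOIN events e
          ON e.product = g.product AND e.ts = g.ts AND e.location = g.location
    ),
    totals AS (
        -- locations not yet seen at a timestamp stay NULL and are ignored by SUM
        SELECT product, ts, SUM(num) AS total
        FROM filled
        GROUP BY product, ts
    )
    SELECT e.product, e.location, e.num, e.ts, t.total
    FROM events e
    JOIN totals t
      ON e.product = t.product AND e.ts = t.ts
""").show()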