I have the tables below:
Customer budget audit:

| Audit ID | Customer ID | Budget | TimeStamp |
|---|---|---|---|
| 1 | 123 | 100 | 2023-05-01 07:40:56 |
| 2 | 456 | 70 | 2023-05-01 12:20:50 |
| 3 | 456 | 70 | 2023-05-01 17:30:50 |
Product price audit:

| Audit ID | Product ID | Price | TimeStamp |
|---|---|---|---|
| 5 | 5556 | 5 | 2023-05-01 06:40:56 |
| 6 | 5556 | 90 | 2023-05-01 06:40:56 |
| 7 | 7778 | 20 | 2023-05-01 12:20:50 |
| 9 | 7987 | 60 | 2023-05-01 05:50:00 |
| 10 | 7987 | 50 | 2023-05-04 05:50:00 |
Customer-product mapping:

| Customer ID | Product ID |
|---|---|
| 123 | 5556 |
| 123 | 7987 |
| 456 | 7778 |
| 456 | 7987 |
Problem statement: find the count of cases where a customer's budget is more than the product's price (picking the latest product price recorded before the customer's budget timestamp), and also the max delta between customer budget and product price. For example, for customer 456's budget of 70 at 2023-05-01 17:30:50, the latest prior prices are 20 for product 7778 and 60 for product 7987, so both cases count and the max delta is 70 - 20 = 50.
Basically, I need the PySpark equivalent of the Python code below. I ran it with pandas and it worked fine for a small dataset, but pandas is not able to process the large dataset. I came across PySpark and read that it's faster, but it seems we cannot write nested loops in PySpark.
# (Fields like customer_id, customer_budget, product_id, product_price and the
# timestamps are read from the current customer_row / product_row.)
count_instances_budget_more_than_price = 0
map_customer_id_max_delta = {}
processed_product_for_customer = set()

for customer_row in customer_dataset:
    # Start from any delta already recorded for this customer
    # (a customer can have multiple budget rows).
    max_delta = map_customer_id_max_delta.get(customer_id, 0)
    for product_row in product_dataset:
        # Only consider products mapped to this customer, each one once.
        if product_id in map_customer_id_product_id[customer_id]:
            if product_id not in processed_product_for_customer:
                processed_product_for_customer.add(product_id)
                if product_timestamp < customer_timestamp and product_price < customer_budget:
                    count_instances_budget_more_than_price += 1
                    max_delta = max(max_delta, customer_budget - product_price)
    map_customer_id_max_delta[customer_id] = max_delta
    processed_product_for_customer.clear()
I think you just need to join the three tables, aggregate on CustomerId, count the number of matched products, and calculate the max difference for each customer:
Input:
from datetime import datetime
import pyspark.sql.functions as F
from pyspark.sql.types import *
from pyspark.sql.window import Window
customerProductDf = spark.createDataFrame(
    [(123, 5556),
     (123, 7987),
     (456, 7778),
     (456, 7987)],
    StructType([
        StructField("CustomerId", IntegerType(), True),
        StructField("ProductId", IntegerType(), True)
    ]))

customersDf = spark.createDataFrame(
    [(123, 100, datetime.strptime('2023-05-01 07:40:56', '%Y-%m-%d %H:%M:%S')),
     (456, 70, datetime.strptime('2023-05-01 12:20:50', '%Y-%m-%d %H:%M:%S')),
     (456, 70, datetime.strptime('2023-05-01 17:30:50', '%Y-%m-%d %H:%M:%S'))],
    StructType([
        StructField("CustomerId", IntegerType(), True),
        StructField("Budget", IntegerType(), True),
        StructField("TimeStamp", TimestampType(), True)
    ]))

productsDf = spark.createDataFrame(
    [(5556, 5, datetime.strptime('2023-05-01 06:40:56', '%Y-%m-%d %H:%M:%S')),
     (5556, 90, datetime.strptime('2023-05-01 05:40:56', '%Y-%m-%d %H:%M:%S')),
     (7778, 20, datetime.strptime('2023-05-01 12:20:50', '%Y-%m-%d %H:%M:%S')),
     (7987, 60, datetime.strptime('2023-05-01 05:50:00', '%Y-%m-%d %H:%M:%S')),
     (7987, 50, datetime.strptime('2023-05-04 05:50:00', '%Y-%m-%d %H:%M:%S'))],
    StructType([
        StructField("ProductId", IntegerType(), True),
        StructField("Price", IntegerType(), True),
        StructField("TimeStamp", TimestampType(), True)
    ]))
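In practice, for the large dataset mentioned in the question, these inline DataFrames would instead be read from storage; a minimal sketch, assuming the three tables live in CSV files at hypothetical paths:

```python
# Hypothetical file paths; the schemas match the inline DataFrames above.
customerProductDf = spark.read.csv("/data/customer_product.csv", header=True, inferSchema=True)
customersDf = spark.read.csv("/data/customer_budget_audit.csv", header=True, inferSchema=True)
productsDf = spark.read.csv("/data/product_price_audit.csv", header=True, inferSchema=True)
```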
Calculations:
window = Window.partitionBy(customersDf.CustomerId, productsDf.ProductId).orderBy(productsDf.TimeStamp)

customerVsDeltaDf = (
    customerProductDf
    .join(customersDf, 'CustomerId')
    .join(productsDf, 'ProductId')
    # Keep only prices recorded before the budget timestamp and below the budget.
    .filter((customersDf.TimeStamp > productsDf.TimeStamp) & (customersDf.Budget > productsDf.Price))
    .withColumn("LatestPrice", F.last(productsDf.Price).over(window))
    .drop(productsDf.Price)
    # Drop any duplicates so they do not affect the count.
    .distinct()
    .groupBy(customersDf.CustomerId)
    .agg(
        F.count(productsDf.ProductId).alias('Count'),
        F.max(customersDf.Budget - F.col('LatestPrice')).alias('MaxPriceDiff')
    )
)
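One caveat: when a window has an orderBy but no explicit frame, Spark defaults to rangeBetween(Window.unboundedPreceding, Window.currentRow), so F.last effectively returns the current row's price rather than the newest one in the partition. If the goal is to pin LatestPrice to the latest price per (customer, product) pair, the frame can be widened explicitly; a minimal sketch of that variant:

```python
# Widen the frame to the whole partition so F.last really returns the
# price with the greatest product TimeStamp (prices after the budget
# timestamp have already been removed by the filter above).
window = Window \
    .partitionBy(customersDf.CustomerId, productsDf.ProductId) \
    .orderBy(productsDf.TimeStamp) \
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
```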
Result:
>>> customerVsDeltaDf.show()
+----------+-----+------------+
|CustomerId|Count|MaxPriceDiff|
+----------+-----+------------+
| 456| 3| 50|
| 123| 2| 95|
+----------+-----+------------+
>>> customerVsDeltaDf.agg(F.sum('Count').alias("TotalCount")).show()
+----------+
|TotalCount|
+----------+
| 5|
+----------+
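If the map_customer_id_max_delta dict from the original loop is still needed on the driver, it can be built from the aggregated result, which is small (one row per customer):

```python
# Collect the per-customer aggregate into a plain Python dict.
map_customer_id_max_delta = {
    row['CustomerId']: row['MaxPriceDiff']
    for row in customerVsDeltaDf.collect()
}
# -> {456: 50, 123: 95}
```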