Tags: python, pandas, apache-spark, pyspark, outer-join

Join operation equivalent of nested for loop in PySpark?


I have the three tables below:

Customer budget table:

Audit ID  Customer ID  Budget  TimeStamp
1         123          100     2023-05-01 07:40:56
2         456          70      2023-05-01 12:20:50
3         456          70      2023-05-01 17:30:50

Product price table:

Audit ID  Product ID  Price  TimeStamp
5         5556        5      2023-05-01 06:40:56
6         5556        90     2023-05-01 06:40:56
7         7778        20     2023-05-01 12:20:50
9         7987        60     2023-05-01 05:50:00
10        7987        50     2023-05-04 05:50:00

Customer-product mapping table:

Customer ID  Product ID
123          5556
123          7987
456          7778
456          7987

Problem statement: find the number of instances where the customer budget is more than the product price (picking the latest product price before the customer budget's timestamp), and also the maximum delta between customer budget and product price.

Basically, I need the PySpark equivalent of the Python code below. I ran it with pandas and it worked fine for a small dataset, but for a large dataset pandas is not able to process it. I came across PySpark and read that it is faster, but it seems we cannot write nested loops in PySpark.

count_instances_budget_more_than_price = 0
map_customer_id_max_delta = {}
processed_product_for_customer = set()

for customer_row in customer_dataset:
  # each row is assumed to expose its column values by name
  customer_id = customer_row["Customer ID"]
  customer_budget = customer_row["Budget"]
  customer_timestamp = customer_row["TimeStamp"]
  # carry over any delta already recorded for this customer
  max_delta = map_customer_id_max_delta.get(customer_id, 0)
  for product_row in product_dataset:
    product_id = product_row["Product ID"]
    product_price = product_row["Price"]
    product_timestamp = product_row["TimeStamp"]
    # map_customer_id_product_id maps a customer id to the set of product ids mapped to it
    if product_id in map_customer_id_product_id[customer_id]:
      if product_id not in processed_product_for_customer:
        processed_product_for_customer.add(product_id)
        if product_timestamp < customer_timestamp and product_price < customer_budget:
          count_instances_budget_more_than_price += 1
          max_delta = max(max_delta, customer_budget - product_price)
  map_customer_id_max_delta[customer_id] = max_delta
  processed_product_for_customer.clear()

Solution

  • I think you just need to join the 3 tables, aggregate on the customer id, count the number of matched products, and calculate the max difference for each customer.

    Input:

    from datetime import datetime
    import pyspark.sql.functions as F
    from pyspark.sql.types import *
    from pyspark.sql.window import Window
    
    customerProductDf = spark.createDataFrame(
        [(123, 5556),
         (123, 7987),
         (456, 7778),
         (456, 7987)],
         StructType([
            StructField("CustomerId", IntegerType(), True), 
            StructField("ProductId", IntegerType(), True)
        ]))
    customersDf = spark.createDataFrame(
        [(123, 100, datetime.strptime('2023-05-01 07:40:56', '%Y-%m-%d %H:%M:%S')),
         (456, 70, datetime.strptime('2023-05-01 12:20:50', '%Y-%m-%d %H:%M:%S')),
         (456, 70, datetime.strptime('2023-05-01 17:30:50', '%Y-%m-%d %H:%M:%S'))],
         StructType([
            StructField("CustomerId", IntegerType(), True), 
            StructField("Budget", IntegerType(), True), 
            StructField("TimeStamp", TimestampType(), True)
        ]))
    productsDf = spark.createDataFrame(
        [(5556, 5, datetime.strptime('2023-05-01 06:40:56', '%Y-%m-%d %H:%M:%S')),
         (5556, 90, datetime.strptime('2023-05-01 05:40:56', '%Y-%m-%d %H:%M:%S')),
         (7778, 20, datetime.strptime('2023-05-01 12:20:50', '%Y-%m-%d %H:%M:%S')),
         (7987, 60, datetime.strptime('2023-05-01 05:50:00', '%Y-%m-%d %H:%M:%S')),
         (7987, 50, datetime.strptime('2023-05-04 05:50:00', '%Y-%m-%d %H:%M:%S'))],
         StructType([
            StructField("ProductId", IntegerType(), True), 
            StructField("Price", IntegerType(), True), 
            StructField("TimeStamp", TimestampType(), True)
        ]))
    

    Calculations:

    window = Window.partitionBy(customersDf.CustomerId, productsDf.ProductId).orderBy(productsDf.TimeStamp)
    customerVsDeltaDf = (
        customerProductDf
        .join(customersDf, 'CustomerId')
        .join(productsDf, 'ProductId')
        # Keep matches where the price record precedes the budget timestamp
        # and the price is below the customer's budget
        .filter((customersDf.TimeStamp > productsDf.TimeStamp) & (customersDf.Budget > productsDf.Price))
        .withColumn("LatestPrice", F.last(productsDf.Price).over(window))
        .drop(productsDf.Price)
        .distinct()  # Drop any duplicates so they do not affect the count
        .groupBy(customersDf.CustomerId)
        .agg(
            F.count(productsDf.ProductId).alias('Count'),
            F.max(customersDf.Budget - F.col('LatestPrice')).alias('MaxPriceDiff')
        )
    )
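
    A note on the window (a possible refinement, not part of the code above): with an orderBy and no explicit frame, Spark's default window frame runs from the start of the partition to the current row, so F.last(...).over(window) yields a running value rather than the single latest price. If the goal is to always pick the overall latest qualifying price per customer/product pair, the frame can be widened explicitly, for example:

    fullWindow = (Window
        .partitionBy(customersDf.CustomerId, productsDf.ProductId)
        .orderBy(productsDf.TimeStamp)
        .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))
    # F.last(productsDf.Price).over(fullWindow) then returns the latest price
    # within each customer/product partition for every row.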
    

    Result:

    >>> customerVsDeltaDf.show()
    +----------+-----+------------+                                                    
    |CustomerId|Count|MaxPriceDiff|
    +----------+-----+------------+
    |       456|    3|          50|
    |       123|    2|          95|
    +----------+-----+------------+
    
    >>> customerVsDeltaDf.agg(F.sum('Count').alias("TotalCount")).show()
    +----------+
    |TotalCount|
    +----------+
    |         5|
    +----------+
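
    If a single overall maximum delta is also needed (the problem statement asks for "the max delta between customer budget and product price"), it can be taken over the same result. A small follow-up sketch, not part of the original run:

    customerVsDeltaDf.agg(F.max('MaxPriceDiff').alias('MaxDelta')).show()
    # Given the per-customer results above, this would print 95 (customer 123: 100 - 5).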