Tags: python, apache-spark, join, pyspark, databricks

PySpark code on Databricks never completes execution and hangs partway through


I have two data frames: df_selected and df_filtered_mins_60

df_filtered_mins_60.columns

Output:["CSku", "start_timestamp", "end_timestamp"]

df_selected.columns

Output:["DATEUPDATED", "DATE", "HOUR", "CPSKU", "BB_Status", "ActivePrice", "PrevPrice", "MinPrice", "AsCost", "MinMargin", "CPT", "Comp_Price", "AP_MSG"]

df_selected.count()

Output: 7,816,521

df_filtered_mins_60.count()

Output: 112,397

What I want to implement: iterate through df_filtered_mins_60 and, for each row, take:
start_time = start_timestamp
stop_time = end_timestamp
sku = CSku
Then apply the following conditions to df_selected:
WHEN DATEUPDATED is between start_time and stop_time (inclusive)
AND CPSKU = sku
THEN assign the constant number i to every row that satisfies the condition. Continue until the last row of df_filtered_mins_60, incrementing i (i = i + 1) after each update.
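To pin down the intended result, here is a small pure-Python sketch of what the loop computes, using plain tuples instead of DataFrames. assign_counters is a hypothetical helper and the rows are a subset of the sample data further down; it only illustrates the semantics (inclusive bounds, 0 for rows outside every window, and a later window overwriting an earlier one when windows overlap) and is not a way to process 7.8 million rows.

```python
from datetime import datetime


def assign_counters(rows, windows):
    """Hypothetical helper mirroring the loop.

    rows:    list of (sku, ts) tuples, standing in for df_selected
    windows: list of (sku, start, stop) tuples, standing in for df_filtered_mins_60
    Returns one counter per row.
    """
    counters = [0] * len(rows)  # every row starts at 0, like lit(0)
    for i, (w_sku, start, stop) in enumerate(windows, start=1):
        for r, (sku, ts) in enumerate(rows):
            # inclusive bounds, same as DATEUPDATED >= start AND DATEUPDATED <= stop
            if sku == w_sku and start <= ts <= stop:
                counters[r] = i  # a later matching window overwrites an earlier one
    return counters


ts = datetime.fromisoformat
rows = [
    ("MSAN10115836", ts("2024-01-01T19:45:39.151")),
    ("MSAN10115836", ts("2024-01-01T20:35:10.904")),
    ("DEWNDCB135C", ts("2024-01-15T07:50:32.412")),
    ("TEST", ts("2027-01-15T07:52:32.412")),
]
windows = [
    ("MSAN10115836", ts("2024-01-01T19:45:39.151"), ts("2024-01-01T20:35:10.904")),
    ("DEWNDCB135C", ts("2024-01-15T07:28:48.265"), ts("2024-01-15T07:52:32.412")),
]
print(assign_counters(rows, windows))  # [1, 1, 2, 0]
```

The nested loop here does the same rows-times-windows amount of work as the Spark loop, which is exactly why a join is the better fit at scale.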

The code I wrote is below. It never finishes; instead it gets stuck somewhere and keeps running for hours until I forcibly stop it.

from pyspark.sql.functions import lit, when

i = 1
df_selected = df_selected.withColumn("counter", lit(0))

# Iterate through each row of df_filtered_mins_60
for row in df_filtered_mins_60.collect():
    sku = row['CSku']
    start_time = row['start_timestamp']
    stop_time = row['stop_timestamp']

    # Apply conditions on df_selected and update "counter" column
    df_selected = df_selected.withColumn("counter", 
                                         when((df_selected.DATEUPDATED >= start_time) & 
                                              (df_selected.DATEUPDATED <= stop_time) &
                                              (df_selected.CPSKU == sku),
                                              lit(i)).otherwise(df_selected.counter))
    
    i += 1

# Display the updated df_selected DataFrame with the "counter" column
display(df_selected)

I am assigning counters because, for each SKU, I need the set of rows in df_selected that fall within certain time windows, and those windows are defined in df_filtered_mins_60. After assigning a counter, I need to aggregate other columns of df_selected so that, for each window, I get insight into what was happening during it.

I need correct PySpark code that runs on Databricks.

Generate Sample Data:

from pyspark.sql import SparkSession
from pyspark.sql.functions import to_timestamp
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, TimestampType

# Initialize SparkSession (on Databricks, getOrCreate() returns the notebook's existing session)
spark = SparkSession.builder \
    .appName("Create DataFrame") \
    .getOrCreate()

schema = StructType([
    StructField("DATEUPDATED", StringType(), True),
    StructField("DATE", StringType(), True),
    StructField("HOUR", IntegerType(), True),
    StructField("CPSKU", StringType(), True),
    StructField("BB_Status", IntegerType(), True),
    StructField("ActivePrice", DoubleType(), True),
    StructField("PrevPrice", DoubleType(), True),
    StructField("MinPrice", DoubleType(), True),
    StructField("AsCost", DoubleType(), True),
    StructField("MinMargin", DoubleType(), True),
    StructField("CPT", DoubleType(), True),
    StructField("Comp_Price", DoubleType(), True)
])

data=[('2024-01-01T19:45:39.151+00:00','2024-01-01',0,'MSAN10115836',0,14.86,14.86,14.86,12.63,0.00,13.90,5.84) ,
('2024-01-01T19:55:10.904+00:00','2024-01-01',0,'MSAN10115836',0,126.04,126.04,126.04,108.96,0.00,0.00,93.54),
('2024-01-01T20:35:10.904+00:00','2024-01-01',0,'MSAN10115836',0,126.04,126.04,126.04,108.96,0.00,0.00,93.54),
('2024-01-15T12:55:18.528+00:00','2024-01-01',1,'PFXNDDF4OX',1,18.16,18.16,10.56,26.85,-199.00,18.16,34.10) ,
('2024-01-15T13:25:18.528+00:00','2024-01-01',1,'PFXNDDF4OX',1,18.16,18.16,10.56,26.85,-199.00,18.16,34.10) ,
('2024-01-15T13:35:18.528+00:00','2024-01-01',1,'PFXNDDF4OX',1,18.16,18.16,10.56,26.85,-199.00,18.16,34.10) ,
('2024-01-15T13:51:09.574+00:00','2024-01-01',1,'PFXNDDF4OX',1,20.16,18.16,10.56,26.85,-199.00,18.16,34.10) ,
('2024-01-15T07:28:48.265+00:00','2024-01-01',1,'DEWNDCB135C',0,44.93,44.93,44.93,38.09,0.25,26.9,941.26),
('2024-01-15T07:50:32.412+00:00','2024-01-01',1,'DEWNDCB135C',0,44.93,44.93,44.93,38.09,0.25,26.9,941.26),
('2024-01-15T07:52:32.412+00:00','2024-01-01',1,'DEWNDCB135C',0,44.93,44.93,44.93,38.09,0.25,26.9,941.26)]

df_selected = spark.createDataFrame(data, schema=schema)
df_selected = df_selected.withColumn("DATEUPDATED", to_timestamp(df_selected["DATEUPDATED"], "yyyy-MM-dd'T'HH:mm:ss.SSS'+00:00'"))
display(df_selected)

Second Dataframe:

schema = StructType([
    StructField("CPSKU", StringType(), True),
    StructField("start_timestamp", StringType(), True),
    StructField("stop_timestamp", StringType(), True)
])
data_2=[('MSAN10115836','2024-01-01T19:45:39.151+00:00','2024-01-01T20:35:10.904+00:00'),
('MSAN10115836','2024-01-08T06:04:16.484+00:00','2024-01-08T06:42:14.912+00:00'),
('DEWNDCB135C','2024-01-15T07:28:48.265+00:00','2024-01-15T07:52:32.412+00:00'),
('DEWNDCB135C','2024-01-15T11:37:56.698+00:00','2024-01-15T12:35:09.693+00:00'),
('PFXNDDF4OX','2024-01-15T12:55:18.528+00:00','2024-01-15T13:51:09.574+00:00'),
('PFXNDDF4OX','2024-01-15T19:25:10.150+00:00','2024-01-15T20:24:36.385+00:00')]

df_filtered_mins_60 = spark.createDataFrame(data_2, schema=schema)
df_filtered_mins_60 = df_filtered_mins_60.withColumn("start_timestamp", to_timestamp(df_filtered_mins_60["start_timestamp"], "yyyy-MM-dd'T'HH:mm:ss.SSS'+00:00'"))
df_filtered_mins_60 = df_filtered_mins_60.withColumn("stop_timestamp", to_timestamp(df_filtered_mins_60["stop_timestamp"], "yyyy-MM-dd'T'HH:mm:ss.SSS'+00:00'"))
display(df_filtered_mins_60)

Solution

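  • As discussed in the comments, your code takes forever for two reasons. First, .collect() pulls all 112,397 rows of df_filtered_mins_60 onto the driver and the loop then runs serially there, so none of the work takes advantage of the data being distributed. Second, every iteration calls withColumn again, which stacks another projection that rewrites the entire counter column on top of the query plan. After the loop the plan is 112,397 projections deep, and Spark has to analyze that ever-growing plan on every iteration and then optimize it before anything is computed. The PySpark documentation for withColumn warns that calling it in a loop can generate big plans that cause performance issues and even a StackOverflowException.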

    If you need an ordered counter, you can instead assign a row number to each row of df_filtered_mins_60 and then left-join df_selected with df_filtered_mins_60 on these conditions:

    [
        df_selected.CPSKU == df_filtered_mins_60.CPSKU,
        df_selected.DATEUPDATED >= df_filtered_mins_60.start_timestamp,
        df_selected.DATEUPDATED <= df_filtered_mins_60.stop_timestamp,
    ]
    

    A left join keeps every row of df_selected, including those that don't meet the join conditions; those rows get a NULL counter, which you can then replace with 0.
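One behavioural difference from the original loop is worth checking on real data: if two windows for the same SKU overlap, a df_selected row inside both is emitted once per matching window, whereas the loop kept only the last matching counter. The pure-Python model below illustrates the left join's semantics; left_range_join is a hypothetical helper, and the SKUs "A"/"B" and timestamps are made up for illustration.

```python
from datetime import datetime


def left_range_join(rows, windows):
    """Hypothetical helper modelling the left range join.

    rows:    list of (sku, ts) tuples, standing in for df_selected
    windows: list of (counter, sku, start, stop) tuples, standing in for
             df_filtered_mins_60 after row_number() has added the counter
    Returns (sku, ts, counter) tuples.
    """
    out = []
    for sku, ts in rows:
        matches = [c for c, w_sku, start, stop in windows
                   if w_sku == sku and start <= ts <= stop]
        if matches:
            # one output row per matching window, as in any SQL join
            out.extend((sku, ts, c) for c in matches)
        else:
            # unmatched row is kept with a NULL counter, then coalesced to 0
            out.append((sku, ts, 0))
    return out


t = datetime.fromisoformat
rows = [("A", t("2024-01-01T00:30:00")), ("B", t("2024-01-01T00:30:00"))]
windows = [
    (1, "A", t("2024-01-01T00:00:00"), t("2024-01-01T01:00:00")),
    (2, "A", t("2024-01-01T00:15:00"), t("2024-01-01T00:45:00")),  # overlaps window 1
]
for sku, ts, counter in left_range_join(rows, windows):
    print(sku, ts, counter)
# A 2024-01-01 00:30:00 1
# A 2024-01-01 00:30:00 2
# B 2024-01-01 00:30:00 0
```

On the real data, comparing df_joined.count() with df_selected.count() is a quick check: the left join only produces extra rows when some row falls inside more than one window.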

    Below is a fully reproducible example (and I have added a row in df_selected that won't meet the join condition just to show what happens):

    from pyspark.sql import SparkSession
    import pyspark.sql.functions as F
    from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, TimestampType
    from pyspark.sql.window import Window
    
    # Initialize SparkSession
    spark = SparkSession.builder \
        .appName("Create DataFrame") \
        .getOrCreate()
    
    schema = StructType([
        StructField("DATEUPDATED", StringType(), True),
        StructField("DATE", StringType(), True),
        StructField("HOUR", IntegerType(), True),
        StructField("CPSKU", StringType(), True),
        StructField("BB_Status", IntegerType(), True),
        StructField("ActivePrice", DoubleType(), True),
        StructField("PrevPrice", DoubleType(), True),
        StructField("MinPrice", DoubleType(), True),
        StructField("AsCost", DoubleType(), True),
        StructField("MinMargin", DoubleType(), True),
        StructField("CPT", DoubleType(), True),
        StructField("Comp_Price", DoubleType(), True)
    ])
    
    data=[('2024-01-01T19:45:39.151+00:00','2024-01-01',0,'MSAN10115836',0,14.86,14.86,14.86,12.63,0.00,13.90,5.84) ,
    ('2024-01-01T19:55:10.904+00:00','2024-01-01',0,'MSAN10115836',0,126.04,126.04,126.04,108.96,0.00,0.00,93.54),
    ('2024-01-01T20:35:10.904+00:00','2024-01-01',0,'MSAN10115836',0,126.04,126.04,126.04,108.96,0.00,0.00,93.54),
    ('2024-01-15T12:55:18.528+00:00','2024-01-01',1,'PFXNDDF4OX',1,18.16,18.16,10.56,26.85,-199.00,18.16,34.10) ,
    ('2024-01-15T13:25:18.528+00:00','2024-01-01',1,'PFXNDDF4OX',1,18.16,18.16,10.56,26.85,-199.00,18.16,34.10) ,
    ('2024-01-15T13:35:18.528+00:00','2024-01-01',1,'PFXNDDF4OX',1,18.16,18.16,10.56,26.85,-199.00,18.16,34.10) ,
    ('2024-01-15T13:51:09.574+00:00','2024-01-01',1,'PFXNDDF4OX',1,20.16,18.16,10.56,26.85,-199.00,18.16,34.10) ,
    ('2024-01-15T07:28:48.265+00:00','2024-01-01',1,'DEWNDCB135C',0,44.93,44.93,44.93,38.09,0.25,26.9,941.26),
    ('2024-01-15T07:50:32.412+00:00','2024-01-01',1,'DEWNDCB135C',0,44.93,44.93,44.93,38.09,0.25,26.9,941.26),
    ('2024-01-15T07:52:32.412+00:00','2024-01-01',1,'DEWNDCB135C',0,44.93,44.93,44.93,38.09,0.25,26.9,941.26),
    ('2027-01-15T07:52:32.412+00:00','2024-01-01',1,'TEST',0,44.93,44.93,44.93,38.09,0.25,26.9,941.26)]
    
    df_selected = spark.createDataFrame(data, schema=schema)
    df_selected = df_selected.withColumn("DATEUPDATED", F.to_timestamp(df_selected["DATEUPDATED"], "yyyy-MM-dd'T'HH:mm:ss.SSS'+00:00'"))
    
    schema = StructType([
        StructField("CPSKU", StringType(), True),
        StructField("start_timestamp", StringType(), True),
        StructField("stop_timestamp", StringType(), True)
    ])
    data_2=[('MSAN10115836','2024-01-01T19:45:39.151+00:00','2024-01-01T20:35:10.904+00:00'),
    ('MSAN10115836','2024-01-08T06:04:16.484+00:00','2024-01-08T06:42:14.912+00:00'),
    ('DEWNDCB135C','2024-01-15T07:28:48.265+00:00','2024-01-15T07:52:32.412+00:00'),
    ('DEWNDCB135C','2024-01-15T11:37:56.698+00:00','2024-01-15T12:35:09.693+00:00'),
    ('PFXNDDF4OX','2024-01-15T12:55:18.528+00:00','2024-01-15T13:51:09.574+00:00'),
    ('PFXNDDF4OX','2024-01-15T19:25:10.150+00:00','2024-01-15T20:24:36.385+00:00')]
    
    df_filtered_mins_60 = spark.createDataFrame(data_2, schema=schema)
    df_filtered_mins_60 = df_filtered_mins_60.withColumn("start_timestamp", F.to_timestamp(df_filtered_mins_60["start_timestamp"], "yyyy-MM-dd'T'HH:mm:ss.SSS'+00:00'"))
    df_filtered_mins_60 = df_filtered_mins_60.withColumn("stop_timestamp", F.to_timestamp(df_filtered_mins_60["stop_timestamp"], "yyyy-MM-dd'T'HH:mm:ss.SSS'+00:00'"))
    
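    # Ordering by a constant numbers rows in whatever order they arrive;
    # a window without partitionBy runs on a single partition (fine for ~112k rows)
    w = Window.orderBy(F.lit(0))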
    df_filtered_mins_60 = df_filtered_mins_60.withColumn("counter", F.row_number().over(w))
    
    df_joined = df_selected.join(
        df_filtered_mins_60, 
        on=[
            df_selected.CPSKU == df_filtered_mins_60.CPSKU,
            df_selected.DATEUPDATED >= df_filtered_mins_60.start_timestamp,
            df_selected.DATEUPDATED <= df_filtered_mins_60.stop_timestamp,
        ],
        how='left' # keep all rows from df_selected
    ).drop(
        df_filtered_mins_60.CPSKU  # drop the duplicate join key from the right side
    ).drop(
        'start_timestamp', 'stop_timestamp'
    ).withColumn(
        'counter', F.coalesce(F.col('counter'), F.lit(0))  # rows outside every window get 0
    ).orderBy(
        'CPSKU','DATEUPDATED'
    )
    
    # .show() returns None, so call it separately to keep df_joined usable
    df_joined.show(truncate=False)
    

    df_selected looks like this:

    +-----------------------+----------+----+------------+---------+-----------+---------+--------+------+---------+-----+----------+
    |DATEUPDATED            |DATE      |HOUR|CPSKU       |BB_Status|ActivePrice|PrevPrice|MinPrice|AsCost|MinMargin|CPT  |Comp_Price|
    +-----------------------+----------+----+------------+---------+-----------+---------+--------+------+---------+-----+----------+
    |2024-01-01 19:45:39.151|2024-01-01|0   |MSAN10115836|0        |14.86      |14.86    |14.86   |12.63 |0.0      |13.9 |5.84      |
    |2024-01-01 19:55:10.904|2024-01-01|0   |MSAN10115836|0        |126.04     |126.04   |126.04  |108.96|0.0      |0.0  |93.54     |
    |2024-01-01 20:35:10.904|2024-01-01|0   |MSAN10115836|0        |126.04     |126.04   |126.04  |108.96|0.0      |0.0  |93.54     |
    |2024-01-15 12:55:18.528|2024-01-01|1   |PFXNDDF4OX  |1        |18.16      |18.16    |10.56   |26.85 |-199.0   |18.16|34.1      |
    |2024-01-15 13:25:18.528|2024-01-01|1   |PFXNDDF4OX  |1        |18.16      |18.16    |10.56   |26.85 |-199.0   |18.16|34.1      |
    |2024-01-15 13:35:18.528|2024-01-01|1   |PFXNDDF4OX  |1        |18.16      |18.16    |10.56   |26.85 |-199.0   |18.16|34.1      |
    |2024-01-15 13:51:09.574|2024-01-01|1   |PFXNDDF4OX  |1        |20.16      |18.16    |10.56   |26.85 |-199.0   |18.16|34.1      |
    |2024-01-15 07:28:48.265|2024-01-01|1   |DEWNDCB135C |0        |44.93      |44.93    |44.93   |38.09 |0.25     |26.9 |941.26    |
    |2024-01-15 07:50:32.412|2024-01-01|1   |DEWNDCB135C |0        |44.93      |44.93    |44.93   |38.09 |0.25     |26.9 |941.26    |
    |2024-01-15 07:52:32.412|2024-01-01|1   |DEWNDCB135C |0        |44.93      |44.93    |44.93   |38.09 |0.25     |26.9 |941.26    |
    |2027-01-15 07:52:32.412|2024-01-01|1   |TEST        |0        |44.93      |44.93    |44.93   |38.09 |0.25     |26.9 |941.26    |
    +-----------------------+----------+----+------------+---------+-----------+---------+--------+------+---------+-----+----------+
    

    df_filtered_mins_60 looks like this:

    +------------+-----------------------+-----------------------+-------+
    |CPSKU       |start_timestamp        |stop_timestamp         |counter|
    +------------+-----------------------+-----------------------+-------+
    |MSAN10115836|2024-01-01 19:45:39.151|2024-01-01 20:35:10.904|1      |
    |MSAN10115836|2024-01-08 06:04:16.484|2024-01-08 06:42:14.912|2      |
    |DEWNDCB135C |2024-01-15 07:28:48.265|2024-01-15 07:52:32.412|3      |
    |DEWNDCB135C |2024-01-15 11:37:56.698|2024-01-15 12:35:09.693|4      |
    |PFXNDDF4OX  |2024-01-15 12:55:18.528|2024-01-15 13:51:09.574|5      |
    |PFXNDDF4OX  |2024-01-15 19:25:10.15 |2024-01-15 20:24:36.385|6      |
    +------------+-----------------------+-----------------------+-------+
    

    And df_joined looks like this:

    +-----------------------+----------+----+------------+---------+-----------+---------+--------+------+---------+-----+----------+-------+
    |DATEUPDATED            |DATE      |HOUR|CPSKU       |BB_Status|ActivePrice|PrevPrice|MinPrice|AsCost|MinMargin|CPT  |Comp_Price|counter|
    +-----------------------+----------+----+------------+---------+-----------+---------+--------+------+---------+-----+----------+-------+
    |2024-01-15 07:28:48.265|2024-01-01|1   |DEWNDCB135C |0        |44.93      |44.93    |44.93   |38.09 |0.25     |26.9 |941.26    |3      |
    |2024-01-15 07:50:32.412|2024-01-01|1   |DEWNDCB135C |0        |44.93      |44.93    |44.93   |38.09 |0.25     |26.9 |941.26    |3      |
    |2024-01-15 07:52:32.412|2024-01-01|1   |DEWNDCB135C |0        |44.93      |44.93    |44.93   |38.09 |0.25     |26.9 |941.26    |3      |
    |2024-01-01 19:45:39.151|2024-01-01|0   |MSAN10115836|0        |14.86      |14.86    |14.86   |12.63 |0.0      |13.9 |5.84      |1      |
    |2024-01-01 19:55:10.904|2024-01-01|0   |MSAN10115836|0        |126.04     |126.04   |126.04  |108.96|0.0      |0.0  |93.54     |1      |
    |2024-01-01 20:35:10.904|2024-01-01|0   |MSAN10115836|0        |126.04     |126.04   |126.04  |108.96|0.0      |0.0  |93.54     |1      |
    |2024-01-15 12:55:18.528|2024-01-01|1   |PFXNDDF4OX  |1        |18.16      |18.16    |10.56   |26.85 |-199.0   |18.16|34.1      |5      |
    |2024-01-15 13:25:18.528|2024-01-01|1   |PFXNDDF4OX  |1        |18.16      |18.16    |10.56   |26.85 |-199.0   |18.16|34.1      |5      |
    |2024-01-15 13:35:18.528|2024-01-01|1   |PFXNDDF4OX  |1        |18.16      |18.16    |10.56   |26.85 |-199.0   |18.16|34.1      |5      |
    |2024-01-15 13:51:09.574|2024-01-01|1   |PFXNDDF4OX  |1        |20.16      |18.16    |10.56   |26.85 |-199.0   |18.16|34.1      |5      |
    |2027-01-15 07:52:32.412|2024-01-01|1   |TEST        |0        |44.93      |44.93    |44.93   |38.09 |0.25     |26.9 |941.26    |0      |
    +-----------------------+----------+----+------------+---------+-----------+---------+--------+------+---------+-----+----------+-------+