Tags: python, apache-spark, pyspark, row

PySpark: Create new rows (explode) based on a number in a column and multiple conditions


I have a dataframe with a few columns: a unique ID, a month, and a split. I need to explode the dataframe, creating new rows for each unique combination of ID, month, and split. The number of rows to create has already been calculated and is stored in the column bad_call_dist. For example, if ID is 12345, month is Jan, split is 'A', and bad_call_dist is 6, I need a total of 6 rows for that combination. This must repeat for each unique combination.

I have code that works for small datasets; however, I need to apply it to a much, much larger dataframe, and it times out every time. The code below extracts a single-row dataframe from the original data, with a temporary range column representing how many rows must exist for a unique column combination. I then use explode() to generate the new rows and union the result into a master dataframe. I'm looking for help optimizing the code and speeding up processing while producing the same outcome:

from pyspark.sql import functions as F

# unique ID-month-split combinations for the data
idMonthSplits = call_data.select('id', 'month', 'split').distinct().collect()

# set the schema to all cols except the bad call flag, which is set to 1 in the loop
master_explode = spark.createDataFrame(
    [],
    schema=call_data.select([col for col in call_data.columns if col != 'bad_call_flag']).schema
)

# loop over each unique combination
for ims in idMonthSplits:

    id = ims['id']
    month = ims['month']
    split = ims['split']

    # single-row df for this combination; it will be exploded into n rows,
    # where n is the value in bad_call_dist
    explode_df = call_data.filter(
        (call_data['id'] == id) & (call_data['month'] == month) & (call_data['split'] == split)
    ).withColumn('bad_call_flag', F.lit(1))

    try:

        # extract the value that represents the number of rows to explode
        expVal = explode_df.select(F.first(F.col("bad_call_dist")).cast("int")).first()[0]

        # range that is used by explode() to convert the single row to multiple rows
        explode_df = explode_df.withColumn(
            'range',
            F.array(
                [F.lit(i) for i in range(expVal + 1)]
            )
        )

        # explode the df, then drop cols no longer needed for union
        explode_df = explode_df.withColumn('explode', F.explode(F.col('range')))\
            .drop(*['explode', 'range', 'bad_call_dist'])

        # union to master df
        master_explode = master_explode.unionAll(explode_df)

    # if the explode value is 0, there is no need to expand rows;
    # this catch avoids an error and skips to the next combination
    except:
        continue

Solution

  • Loops are almost always disastrous in Spark. It is better to use Spark's built-in functions as much as you can, since they can be optimized internally, and your situation can be solved using array_repeat() within expr().

    Here's an example

    # assumed import for the `func` alias used below
    from pyspark.sql import functions as func

    # given the following data
    # +---+-----+-----+-------------+
    # | id|month|split|bad_call_dist|
    # +---+-----+-----+-------------+
    # |  1|  Jan|    A|            6|
    # |  1|  Feb|    A|            8|
    # +---+-----+-----+-------------+
    
    # create a dummy array to explode using `array_repeat` and explode it
    data_sdf. \
        withColumn('dummy_arr', func.expr('array_repeat(1, cast(bad_call_dist as int))')). \
        selectExpr(*data_sdf.columns, 'explode(dummy_arr) as exp_dummy'). \
        show()
    
    # +---+-----+-----+-------------+---------+
    # |id |month|split|bad_call_dist|exp_dummy|
    # +---+-----+-----+-------------+---------+
    # |1  |Jan  |A    |6            |1        |
    # |1  |Jan  |A    |6            |1        |
    # |1  |Jan  |A    |6            |1        |
    # |1  |Jan  |A    |6            |1        |
    # |1  |Jan  |A    |6            |1        |
    # |1  |Jan  |A    |6            |1        |
    # |1  |Feb  |A    |8            |1        |
    # |1  |Feb  |A    |8            |1        |
    # |1  |Feb  |A    |8            |1        |
    # |1  |Feb  |A    |8            |1        |
    # |1  |Feb  |A    |8            |1        |
    # |1  |Feb  |A    |8            |1        |
    # |1  |Feb  |A    |8            |1        |
    # |1  |Feb  |A    |8            |1        |
    # +---+-----+-----+-------------+---------+
    

    Notice that I used array_repeat within expr. That's because you want the number of repeats to come from a column, and the Spark-native Python function does not accept a column for its second parameter, whereas the SQL function does.
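
    For comparison, here is a minimal sketch of the two forms side by side (the dataframe name data_sdf is from the example above; the result names are hypothetical). The native call takes a plain Python int for the count, so it is the same for every row; whether your PySpark version also accepts a Column there varies, so the expr() form is the portable choice:

    # native Python API form: the repeat count is a fixed int, identical for every row
    fixed_repeat_sdf = data_sdf.withColumn('dummy_arr', func.array_repeat(func.lit(1), 6))

    # SQL expression form: the repeat count comes from the bad_call_dist column, per row
    per_row_repeat_sdf = data_sdf.withColumn(
        'dummy_arr',
        func.expr('array_repeat(1, cast(bad_call_dist as int))')
    )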

    data_sdf. \
        withColumn('dummy_arr', func.expr('array_repeat(1, cast(bad_call_dist as int))')). \
        show(truncate=False)
    
    # +---+-----+-----+-------------+------------------------+
    # |id |month|split|bad_call_dist|dummy_arr               |
    # +---+-----+-----+-------------+------------------------+
    # |1  |Jan  |A    |6            |[1, 1, 1, 1, 1, 1]      |
    # |1  |Feb  |A    |8            |[1, 1, 1, 1, 1, 1, 1, 1]|
    # +---+-----+-----+-------------+------------------------+
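
    To get back to the shape asked for in the question (every column except bad_call_dist, with bad_call_flag set to 1 on each exploded row), a sketch along these lines would replace the whole loop; this assumes data_sdf carries the same columns as call_data in the question:

    # build the exploded result in one pass, no loop or union needed
    master_explode = data_sdf. \
        withColumn('dummy_arr', func.expr('array_repeat(1, cast(bad_call_dist as int))')). \
        withColumn('exp_dummy', func.explode('dummy_arr')). \
        withColumn('bad_call_flag', func.lit(1)). \
        drop('dummy_arr', 'exp_dummy', 'bad_call_dist')

    Combinations where bad_call_dist is 0 or null end up with an empty or null dummy_arr, and explode() simply produces no rows for them, which mirrors the continue branch in the original loop.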