I understand that 'for' and 'while' loops are generally to be avoided when using Spark. My question is about optimizing a 'while' loop, though if I'm missing a solution that makes it unnecessary, I am all ears.
I'm not sure I can demonstrate the issue (very long processing times, compounding as the loop goes on) with toy data, but here is some pseudocode:
# I have a function - called 'enumerator' - which involves several joins and window functions.
from pyspark.sql.functions import col

# I run this function on my base dataset, df0, and return df1
df1 = enumerator(df0, param1=apple, param2=banana)

# Check for some condition in df1, then count the number of rows in the result
counter = df1 \
    .filter(col('X') == some_condition) \
    .count()

# If there are rows meeting this condition, start a while loop
while counter > 0:
    print('Starting with counter:', counter)
    # Run the enumerator function on df1 again
    df2 = enumerator(df1, param1=apple, param2=banana)
    # Check for the condition again, then continue the while loop if necessary
    counter = df2 \
        .filter(col('X') == some_condition) \
        .count()
    df1 = df2

# After the while loop finishes, I take the last resulting dataframe
# and do several more operations and analyses downstream
final_df = df2
An essential aspect of the enumerator function is to 'look back' on a sequence in a window, and so it may take several runs before all the necessary corrections are made.
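To give a flavor of what 'look back' means here, a stripped-down, made-up sketch of the pattern (the real function also involves several joins; 'id', 'seq', and the correction logic are placeholders, not my actual code):

# Illustrative only: a single 'look back' pass with a window function.
# 'id' and 'seq' are hypothetical partition/ordering columns.
from pyspark.sql import functions as F
from pyspark.sql.window import Window

def enumerator(df, param1, param2):
    w = Window.partitionBy('id').orderBy('seq')
    return (df
            .withColumn('prev_X', F.lag('X').over(w))  # look back one row in the sequence
            .withColumn('X', F.when(F.col('prev_X') == param1, param2)
                              .otherwise(F.col('X')))  # correct the current row if needed
            .drop('prev_X'))

Because each pass can only see one step back, a correction made in one run can expose a new row that needs correcting in the next, which is why it may take several runs to converge.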
In my heart, I know this is ugly, but the windowing/ranking/sequential analysis within the function is critical. My understanding is that the underlying Spark query plan gets more and more convoluted as the loop continues. Are there any best practices I should adopt in this situation? Should I be caching at any point - either before the while loop starts, or within the loop itself?
You should definitely cache/persist the dataframes, otherwise every iteration in the while loop will start from scratch from df0. You may also want to unpersist the used dataframes to free up disk/memory space. Another point to optimize: don't do a full count(), but use a cheaper operation such as df.take(1). If that returns an empty list, then counter == 0.
from pyspark.sql.functions import col

df1 = enumerator(df0, param1=apple, param2=banana)
df1.cache()

# Check for some condition in df1; take(1) is cheaper than a full count
counter = len(df1.filter(col('X') == some_condition).take(1))

while counter > 0:
    print('Starting with counter:', counter)
    df2 = enumerator(df1, param1=apple, param2=banana)
    df2.cache()
    counter = len(df2.filter(col('X') == some_condition).take(1))
    df1.unpersist()  # unpersist df1 as it will be overwritten
    df1 = df2

final_df = df2
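One more thing worth noting, since you asked about the query plan: cache() keeps the data around, but it does not shrink the logical plan itself, so the plan Spark has to analyze can still grow with each iteration. If that analysis time becomes the bottleneck, you can truncate the lineage inside the loop with a checkpoint. A rough sketch, assuming the executors have enough storage for the checkpointed data:

# Sketch: replace the cache() calls in the loop with localCheckpoint(),
# which materializes df2 on the executors and cuts its lineage so the
# query plan stops growing from one iteration to the next.
while counter > 0:
    print('Starting with counter:', counter)
    df2 = enumerator(df1, param1=apple, param2=banana).localCheckpoint()
    counter = len(df2.filter(col('X') == some_condition).take(1))
    df1 = df2

If you need the checkpointed data to survive executor failures, use df.checkpoint() together with spark.sparkContext.setCheckpointDir(...) instead; localCheckpoint() is faster but less resilient.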