Tags: apache-spark, pyspark, while-loop, sequential

Caching and Loops in (Py)Spark


I understand that 'for' and 'while' loops are generally to be avoided when using Spark. My question is about optimizing a 'while' loop, though if I'm missing a solution that makes the loop unnecessary, I am all ears.

I'm not sure I can demonstrate the issue (very long processing times that compound as the loop goes on) with toy data, but here is some pseudocode:

from pyspark.sql.functions import col

# I have a function, called 'enumerator', which involves several joins and window functions.
# I run this function on my base dataset, df0, and it returns df1.
# (apple, banana and some_condition are placeholders.)
df1 = enumerator(df0, param1=apple, param2=banana)

# Check for some condition in df1, then count the number of matching rows
counter = df1 \
    .filter(col('X') == some_condition) \
    .count()

# If any rows meet this condition, start a while loop
while counter > 0:
    print('Starting with counter:', counter)

    # Run the enumerator function on the previous result again
    df2 = enumerator(df1, param1=apple, param2=banana)

    # Check for the condition again, then continue the while loop if necessary
    counter = df2 \
        .filter(col('X') == some_condition) \
        .count()

    df1 = df2

# After the while loop finishes, df1 holds the last result (and is still defined if
# the loop never ran). I will do several more operations and analyses on it downstream.
final_df = df1

An essential aspect of the enumerator function is to 'look back' on a sequence in a window, and so it may take several runs before all the necessary corrections are made.
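To make the "several runs" point concrete, here is a plain-Python toy; it is not the author's enumerator, whose logic the question does not show. The hypothetical `fill_pass` looks back exactly one row, like `lag(1)` over an ordered window, and reads the *input* rows rather than rows updated in the same pass, so a run of k consecutive gaps needs k passes. The list stands in for one ordered window partition, and `fill_pass` / `run_until_stable` are made-up names.

```python
from typing import List, Optional, Tuple

Row = Optional[int]  # None marks a row that still needs correcting


def fill_pass(values: List[Row]) -> List[Row]:
    """One hypothetical 'enumerator' pass: a missing value copies the previous
    row's value as it was in the input, like lag(1) over an ordered window."""
    out = list(values)
    for i in range(1, len(values)):
        if values[i] is None and values[i - 1] is not None:
            out[i] = values[i - 1]
    return out


def run_until_stable(values: List[Row], max_passes: int = 100) -> Tuple[List[Row], int]:
    """Repeat fill_pass while rows still need correcting, like the while loop above."""
    current = list(values)
    passes = 0
    while any(v is None for v in current) and passes < max_passes:
        nxt = fill_pass(current)
        passes += 1
        if nxt == current:  # no progress (e.g. a leading None): stop instead of looping forever
            break
        current = nxt
    return current, passes


result, passes = run_until_stable([5, None, None, None, 7, None])
print(result, passes)  # [5, 5, 5, 5, 7, 7] 3
```

If the real enumerator is essentially a forward fill like this toy, Spark can do it in a single pass with `pyspark.sql.functions.last(col, ignorenulls=True)` over a window ordered by the sequence and bounded by `rowsBetween(Window.unboundedPreceding, Window.currentRow)`, which removes the loop altogether; the question does not show whether its corrections fit that pattern.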

In my heart I know this is ugly, but the windowing/ranking/sequential analysis within the function is critical. My understanding is that the underlying Spark query plan gets more and more convoluted as the loop continues. Are there any best practices I should adopt in this situation? Should I be caching at any point, either before the while loop starts or within the loop itself?
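The compounding cost can be reproduced with a toy model of lazy evaluation in plain Python. `LazyFrame`, `simulate` and the trivial `enumerator` are hypothetical stand-ins, not Spark's API: each frame only records its parent and a step, and an action recomputes the whole chain unless a frame is cached. In this model, iteration k without caching re-runs all k enumerator calls, so 10 iterations cost 55 calls instead of 10. Caching removes the recomputation, but the lineage (standing in for the query plan) still grows by one step per iteration; truncating it takes a checkpoint, which in Spark corresponds to `DataFrame.checkpoint()` (after `SparkContext.setCheckpointDir`) or `DataFrame.localCheckpoint()`.

```python
from typing import Callable, List, Optional, Tuple


class LazyFrame:
    """Toy stand-in for a Spark DataFrame (hypothetical, NOT Spark's API).
    It records its parent and a transformation and computes nothing until collect()."""

    def __init__(self, parent: Optional["LazyFrame"] = None,
                 step: Optional[Callable[[List[int]], List[int]]] = None,
                 data: Optional[List[int]] = None) -> None:
        self.parent = parent
        self.step = step
        self.data = data          # only a source frame holds its rows directly
        self.cached = False
        self._memo: Optional[List[int]] = None

    def transform(self, step: Callable[[List[int]], List[int]]) -> "LazyFrame":
        return LazyFrame(parent=self, step=step)  # lazy: only extends the lineage

    def cache(self) -> "LazyFrame":
        self.cached = True        # like df.cache(): rows are kept the next time they are computed
        return self

    def unpersist(self) -> None:
        self.cached = False
        self._memo = None

    def collect(self) -> List[int]:
        if self._memo is not None:
            return self._memo     # cached: the lineage behind this frame is not re-run
        if self.parent is None:
            result = list(self.data or [])
        else:
            result = self.step(self.parent.collect())  # recompute the parent first, recursively
        if self.cached:
            self._memo = result
        return result

    def checkpoint(self) -> "LazyFrame":
        # like localCheckpoint(): materialize now and return a frame with no lineage
        return LazyFrame(data=self.collect())

    def plan_depth(self) -> int:
        return 0 if self.parent is None else 1 + self.parent.plan_depth()


def simulate(iterations: int, use_cache: bool, use_checkpoint: bool = False) -> Tuple[int, int]:
    """Run the question's loop shape; return (enumerator calls, final plan depth)."""
    runs = [0]

    def enumerator(rows: List[int]) -> List[int]:  # hypothetical stand-in
        runs[0] += 1
        return [r + 1 for r in rows]

    df = LazyFrame(data=[0, 0, 0])
    for _ in range(iterations):
        nxt = df.transform(enumerator)
        if use_cache:
            nxt.cache()
        if use_checkpoint:
            nxt = nxt.checkpoint()
        nxt.collect()             # the per-iteration action (count/take in Spark)
        if use_cache:
            df.unpersist()        # the previous frame is no longer needed
        df = nxt
    return runs[0], df.plan_depth()


print(simulate(10, use_cache=False))                       # (55, 10)
print(simulate(10, use_cache=True))                        # (10, 10)
print(simulate(10, use_cache=False, use_checkpoint=True))  # (10, 0)
```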


Solution

  • You should definitely cache/persist the DataFrames; otherwise every action inside the while loop recomputes the whole chain of enumerator calls, starting from df0. You may also want to unpersist DataFrames you no longer need, to free up memory/disk space.

    Another optimization is to avoid count() and use a cheaper action such as df.take(1), which can stop as soon as it finds one matching row. If it returns an empty list, no rows meet the condition, which is equivalent to counter == 0.

    from pyspark.sql.functions import col

    df1 = enumerator(df0, param1=apple, param2=banana)
    df1.cache()

    # Check whether any row in df1 meets the condition; take(1) returns at most one row
    has_rows = len(df1.filter(col('X') == some_condition).take(1)) > 0

    while has_rows:
        print('Condition still met, running enumerator again')

        df2 = enumerator(df1, param1=apple, param2=banana)
        df2.cache()

        has_rows = len(df2.filter(col('X') == some_condition).take(1)) > 0
        df1.unpersist()    # unpersist df1, as it is about to be replaced by df2

        df1 = df2

    final_df = df1    # the last result; also defined if the loop never runs
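Why take(1) can be cheaper than count() is easiest to see in a plain-Python analogy; the functions below are hypothetical, not Spark's API. Counting must read every row, while taking one match can stop at the first hit. In Spark the saving is usually smaller than this suggests: take(1) scans partitions incrementally, but any shuffle that the window functions in enumerator require still runs in full, and if no row matches, every partition is read anyway.

```python
from itertools import islice
from typing import Callable, Iterable, Iterator, List, Tuple


def scan(rows: Iterable[int], touched: List[int]) -> Iterator[int]:
    """Yield rows one at a time, recording each row the 'job' actually reads."""
    for row in rows:
        touched.append(row)
        yield row


def count_matching(rows: Iterable[int], pred: Callable[[int], bool]) -> Tuple[int, int]:
    touched: List[int] = []
    n = sum(1 for r in scan(rows, touched) if pred(r))  # must read every row
    return n, len(touched)


def take_matching(rows: Iterable[int], pred: Callable[[int], bool],
                  k: int = 1) -> Tuple[List[int], int]:
    touched: List[int] = []
    # islice stops pulling rows once k matches have been found
    found = list(islice((r for r in scan(rows, touched) if pred(r)), k))
    return found, len(touched)


def condition(r: int) -> bool:
    return r % 100 == 5  # hypothetical stand-in for col('X') == some_condition


print(count_matching(range(1000), condition))        # (10, 1000)
print(take_matching(range(1000), condition))         # ([5], 6)
print(take_matching(range(1000), lambda r: r < 0))   # ([], 1000): no match means a full scan
```

One caveat for the loop above: cache() is lazy, so an action like take(1) caches only the partitions it actually computes. If the remaining partitions of df2 are needed later, after df1 has been unpersisted, Spark may have to recompute them from further back in the lineage; a full action such as df2.count() right after df2.cache() materializes the whole cache, at the cost that take(1) was meant to save.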