Tags: apache-spark, pyspark, out-of-memory, heap-memory

Spark goes out of memory (Java heap space) on a small collect


I've got a problem with Spark: its driver runs into an out-of-memory (OoM) error.

Currently I have a dataframe built by joining several sources (different tables in Parquet format), and it contains thousands of rows. Each row has a date representing when the record was created, and only a few distinct dates appear among them.

I do the following:

from pyspark.sql.functions import year, month

# ...

selectionRows = inputDataframe.select(
    year('registration_date').alias('year'),
    month('registration_date').alias('month')
).distinct()

selectionRows.show()                    # correctly shows 8 tuples
selectionRows = selectionRows.collect() # goes heap space OoM
print(selectionRows)

The memory-consumption statistics show that the driver never exceeds ~60% usage. I thought the driver would only have to load the distinct subset, not the entire dataframe.

Am I missing something? Is it possible to collect those few rows in a smarter way? I need them as a pushdown predicate to load a secondary dataframe.
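
For context, this is roughly how I intend to use the collected rows once I have them; the predicate simply enumerates the collected (year, month) pairs, and 'secondary_path' is a placeholder, not my real table:

from pyspark.sql.functions import col

# sketch of the intended usage: build an OR-ed (year, month) predicate from the
# collected rows so that Spark can push it down when reading the second table
predicate = None
for row in selectionRows:
    clause = (col('year') == row['year']) & (col('month') == row['month'])
    predicate = clause if predicate is None else (predicate | clause)

secondaryDataframe = spark.read.parquet('secondary_path').filter(predicate)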

Thank you very much!

EDIT / SOLUTION

After reading the comments and thinking through my actual needs, I now cache the dataframe at every "join/elaborate" step, so the timeline looks like this:

  • Join with the loaded table
  • Queue the required transformations
  • Cache the dataframe
  • Print the count to keep track of cardinality (mainly for tracking / debugging purposes), which forces all the queued transformations plus the cache to be applied
  • Unpersist the cache of the previous sibling step, if available (tick/tock paradigm)

This cut some complex ETL jobs down to 20% of their original runtime (previously, each count re-applied the transformations of every preceding step).

Lesson learned :)


Solution

  • After reading the comments, I worked out a solution for my use case.

    As mentioned in the question, I join several tables one after another into a "target" dataframe, and at each iteration I apply some transformations, like so:

    # n-th table work (join keys and real expressions elided)
    target = target.join(other, on=..., how='left')
    target = target.filter(...)
    target = target.withColumn('a', col('b'))  # col comes from pyspark.sql.functions
    target = target.select(...)
    print(f'Target after table "other": {target.count()}')
    

    The cause of the slowness / OoM was that, because of the count at the end, Spark was forced to re-run all the transformations from the very beginning for every table, making it slower and slower at each iteration.

    The solution I found is to cache the dataframe at each iteration, like so:

    from typing import Optional

    from pyspark.sql import DataFrame
    from pyspark.sql.functions import col

    cache: Optional[DataFrame] = None
    
    # ...
    
    # n-th table work (join keys and real expressions elided, as above)
    target = target.join(other, on=..., how='left')
    target = target.filter(...)
    target = target.withColumn('a', col('b'))
    target = target.select(...)
    
    target = target.cache()
    target_count = target.count() # actually materialize the cache
    if cache is not None:
      cache.unpersist() # free the memory held by the previous cached version
    cache = target
    
    print(f'Target after table "other": {target_count}')
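
    To avoid repeating the cache / count / unpersist boilerplate for every table, the same tick/tock idea can be factored into a small helper. This is just a sketch, and "cache_step" is an illustrative name, not part of the Spark API:

    from typing import Optional
    from pyspark.sql import DataFrame

    def cache_step(new_df: DataFrame, old_cache: Optional[DataFrame], label: str) -> DataFrame:
        # cache the new dataframe and force materialization with a count
        new_df = new_df.cache()
        print(f'Target after table "{label}": {new_df.count()}')
        # drop the previously cached version, if any (tick/tock)
        if old_cache is not None:
            old_cache.unpersist()
        return new_df

    # usage after each table's join / filter / select block
    target = cache_step(target, cache, 'other')
    cache = target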