Un-persisting all dataframes in (py)spark


I have a Spark application with several points where I would like to persist the current state. This is usually after a large step, or when caching a state that I would like to use multiple times. It appears that when I call cache() on my dataframe a second time, a new copy is cached to memory. In my application, this leads to memory issues when scaling up. Even though a given dataframe is at most about 100 MB in my current tests, the cumulative size of the intermediate results grows beyond the allotted memory on the executor. See below for a small example that shows this behavior.

cache_test.py:

from pyspark import SparkContext, HiveContext

spark_context = SparkContext(appName='cache_test')
hive_context = HiveContext(spark_context)

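# Requires the spark-csv package; header-less files get the column names C0, C1, C2
df = (
    hive_context
    .read
    .format('com.databricks.spark.csv')
    .load('simple_data.csv')
)
df.cache()  # first cached copy: the raw input
df.show()   # action that materializes the cache

df = df.withColumn('C1+C2', df['C1'] + df['C2'])
df.cache()  # second cached copy; the first one stays in memory
df.show()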

spark_context.stop()

simple_data.csv:

1,2,3
4,5,6
7,8,9

Looking at the application UI, there is a copy of the original dataframe in addition to the one with the new column. I can remove the original copy by calling df.unpersist() before the withColumn line. Is this the recommended way to remove cached intermediate results (i.e., call unpersist() before every cache())?

Also, is it possible to purge all cached objects? In my application, there are natural breakpoints where I can simply purge all memory and move on to the next file. I would like to do this without creating a new Spark application for each input file.

Thank you in advance!


Solution

  • Spark 2.x

    You can use Catalog.clearCache:

    from pyspark.sql import SparkSession
    
    spark = SparkSession.builder.getOrCreate()
    ...
    spark.catalog.clearCache()
    

  • Spark 1.x

    You can use the SQLContext.clearCache method, which "removes all cached tables from the in-memory cache":

    from pyspark.sql import SQLContext
    from pyspark import SparkContext
    
    sqlContext = SQLContext.getOrCreate(SparkContext.getOrCreate())
    ...
    sqlContext.clearCache()