Tags: python, apache-spark, testing, pyspark, hivecontext

How to prevent memory leak when testing with HiveContext in PySpark


I use pyspark to do some data processing and leverage HiveContext for window functions.

In order to test the code, I use TestHiveContext, basically copying the implementation from pyspark source code:

https://spark.apache.org/docs/preview/api/python/_modules/pyspark/sql/context.html

@classmethod
def _createForTesting(cls, sparkContext):
    """(Internal use only) Create a new HiveContext for testing.

    All test code that touches HiveContext *must* go through this method. Otherwise,
    you may end up launching multiple derby instances and encounter with incredibly
    confusing error messages.
    """
    jsc = sparkContext._jsc.sc()
    jtestHive = sparkContext._jvm.org.apache.spark.sql.hive.test.TestHiveContext(jsc)
    return cls(sparkContext, jtestHive)

My tests then inherit the base class which can access the context.
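
The base class looks roughly like this (a simplified sketch, not the exact code; the class name and master URL are illustrative):

import unittest

from pyspark import SparkContext
from pyspark.sql import HiveContext  # with the _createForTesting classmethod copied onto it


class HiveTestCase(unittest.TestCase):
    """Sketch of the base class my tests inherit from."""

    def setUp(self):
        # Local context per test; master URL and app name are just examples.
        self.sc = SparkContext('local[2]', 'hive-context-test')
        self.sql_context = HiveContext._createForTesting(self.sc)

    def tearDown(self):
        # Stopping the SparkContext does not appear to tear down the
        # JVM-side TestHiveContext.
        self.sc.stop()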

This worked fine for a while. However, as I added more tests I started noticing intermittent out-of-memory failures. Now I can't run the test suite without a failure:

"java.lang.OutOfMemoryError: Java heap space"

I explicitly stop the Spark context after each test runs, but that does not appear to kill the HiveContext. I believe a new HiveContext is created for every test and the old ones are never removed, which results in the memory leak.

Any suggestions for how to tear down the base class so that it kills the HiveContext?


Solution

  • If you're happy to use a singleton to hold the Spark/Hive context in all your tests, you can do something like the following.

    test_contexts.py:

    _test_spark = None
    _test_hive = None
    
    def get_test_spark():
        global _test_spark
        if _test_spark is None:
            # Create spark context for tests.
            # Not really sure what's involved here for Python.
            _test_spark = ...
        return _test_spark
    
    def get_test_hive():
        global _test_hive
        if _test_hive is None:
            sc = get_test_spark()
            jsc = sc._jsc.sc()
            _test_hive = sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(jsc)
        return _test_hive
    

    And then you just import these functions in your tests.

    my_test.py:

    from test_contexts import get_test_spark, get_test_hive
    
    def test_some_spark_thing():
        sc = get_test_spark()
        sqlContext = get_test_hive()
        # etc
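
    I've left the SparkContext creation as a placeholder above; for local unit tests it might look something like this (the master URL, app name and memory setting are just examples, not part of the original answer):

    from pyspark import SparkConf, SparkContext
    
    def get_test_spark():
        global _test_spark
        if _test_spark is None:
            # A small local context shared by the whole test suite.
            conf = (SparkConf()
                    .setMaster('local[2]')
                    .setAppName('pyspark-tests')
                    .set('spark.executor.memory', '1g'))
            _test_spark = SparkContext(conf=conf)
        return _test_spark

    Because the module-level singletons are created only once, every test shares the same SparkContext and TestHiveContext, so you no longer stop and recreate them per test. That avoids piling up Derby/Hive instances and the eventual OutOfMemoryError.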