Tags: apache-spark, pyspark, side-effects

Unexpected results when making dicts and lists of RDDs in pyspark


Below is a simple pyspark script that tries to split an RDD into a dictionary containing several RDDs.

As the sample run shows, the script only works if we do a collect() on the intermediate RDDs as they are created. Of course I would not want to do that in practice, since it doesn't scale.

What's really strange is that I'm not assigning the intermediate collect() results to any variable, so the difference in behavior must be due solely to some hidden side effect of the computation triggered by the collect() call.

Spark is supposed to be a very functional framework with minimal side effects. Why is it only possible to get the desired behavior by triggering some mysterious side effect using collect()?

The run below is with Spark 1.5.2, Python 2.7.10, and IPython 4.0.0.

spark_script.py

from pprint import PrettyPrinter
pp = PrettyPrinter(indent=4).pprint
logger = sc._jvm.org.apache.log4j
logger.LogManager.getLogger("org"). setLevel( logger.Level.ERROR )
logger.LogManager.getLogger("akka").setLevel( logger.Level.ERROR )

def split_RDD_by_key(rdd, key_field, key_values, collect_in_loop=False):
    d = dict()
    for key_value in key_values:
        d[key_value] = rdd.filter(lambda row: row[key_field] == key_value)
        if collect_in_loop:
            d[key_value].collect()
    return d

def print_results(d):
    for k in d:
        print k
        pp(d[k].collect())    

rdd = sc.parallelize([
    {'color':'red','size':3},
    {'color':'red', 'size':7},
    {'color':'red', 'size':8},    
    {'color':'red', 'size':10},
    {'color':'green', 'size':9},
    {'color':'green', 'size':5},
    {'color':'green', 'size':50},    
    {'color':'blue', 'size':4},
    {'color':'purple', 'size':6}])
key_field = 'color'
key_values = ['red', 'green', 'blue', 'purple']

print '### run WITH collect in loop: '
d = split_RDD_by_key(rdd, key_field, key_values, collect_in_loop=True)
print_results(d)
print '### run WITHOUT collect in loop: '
d = split_RDD_by_key(rdd, key_field, key_values, collect_in_loop=False)
print_results(d)

Sample run in IPython shell

In [1]: execfile('spark_script.py')
### run WITH collect in loop: 
blue
[{   'color': 'blue', 'size': 4}]
purple
[{   'color': 'purple', 'size': 6}]
green
[   {   'color': 'green', 'size': 9},
    {   'color': 'green', 'size': 5},
    {   'color': 'green', 'size': 50}]
red
[   {   'color': 'red', 'size': 3},
    {   'color': 'red', 'size': 7},
    {   'color': 'red', 'size': 8},
    {   'color': 'red', 'size': 10}]
### run WITHOUT collect in loop: 
blue
[{   'color': 'purple', 'size': 6}]
purple
[{   'color': 'purple', 'size': 6}]
green
[{   'color': 'purple', 'size': 6}]
red
[{   'color': 'purple', 'size': 6}]

Solution

    Short Answer

    As it turns out, this is not so much a Spark issue as a tricky Python feature called late-binding closures. A quick hack to force early binding (the desired behavior in this case) is to add a default argument:

    lambda row, key_value=key_value: row[key_field] == key_value
    

    The other way is with functools.partial.
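    For example, here is a minimal sketch of the loop body rewritten with functools.partial instead of the lambda (rdd, d, key_field, and key_value are the names from the script above; row_matches is just a helper introduced for this sketch):

    from functools import partial

    def row_matches(key_field, key_value, row):
        # key_field and key_value are bound eagerly, at the moment partial() is called
        return row[key_field] == key_value

    d[key_value] = rdd.filter(partial(row_matches, key_field, key_value))

    Because partial() fixes its arguments immediately, each RDD built in the loop gets its own copy of key_value and late binding never comes into play.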

    Long Answer

    When a function is defined in Python, any variables that come from outside the function are looked up in the defining environment (lexical scoping), and that lookup happens when the function is evaluated (called), not when it is defined (late binding). So, in the lambda used by the filter transformation, the value of key_value is not determined until the lambda is actually evaluated.

    You can start to see the danger here: key_value takes several values in the loop of split_RDD_by_key(). What if, when the lambda is evaluated, key_value no longer has the value we wanted? Functions are often evaluated long after they are defined, especially when working with RDDs. Because of the lazy computational semantics of RDDs, the lambda will not be evaluated until an action such as collect() or take() is called to retrieve data.
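    Here is the same trap in plain Python, with no Spark involved (the names below are purely illustrative):

    # Each lambda closes over the loop variable i, not the value i had when the lambda was defined
    late = [lambda: i for i in range(3)]
    print [f() for f in late]      # prints [2, 2, 2] -- every lambda sees the final value of i

    # A default argument forces early binding of the current value
    early = [lambda i=i: i for i in range(3)]
    print [f() for f in early]     # prints [0, 1, 2]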

    In split_RDD_by_key() we are looping over key_values and making a new RDD for each value. When collect_in_loop=False, there is no collect() until after split_RDD_by_key() has finished executing. By then the loop inside is complete, and key_value holds 'purple', the value from the last iteration. So when the lambdas in all the RDDs returned by split_RDD_by_key() are finally evaluated, they all look up key_value, find 'purple', and retrieve only the 'purple' rows of the RDD.

    When collect_in_loop=True, we do a collect() on each iteration, causing the lambda to be evaluated in the same iteration where it was defined, and we get the key_value we expect.
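    Putting the fix into the original function, a corrected split_RDD_by_key() that works without the in-loop collect() is a one-line change (a sketch of the same function from the script above, with the default-argument trick applied):

    def split_RDD_by_key(rdd, key_field, key_values):
        d = dict()
        for key_value in key_values:
            # key_value=key_value binds the current loop value when the lambda is defined
            d[key_value] = rdd.filter(
                lambda row, key_value=key_value: row[key_field] == key_value)
        return d

    key_field is also closed over, but since it never changes inside the loop, late binding does it no harm.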

    This example also reveals an interesting, subtle detail, though it is more about how PySpark handles closures than about Python closures in general. When the in-loop collect() triggers the first job on an RDD, PySpark serializes (pickles) the lambda along with the values its free variables hold at that moment, and it caches and reuses that serialized closure for later actions on the same RDD. That is why the later collect() calls in print_results() still return the rows we expect even though key_value has since changed in the defining environment: the first evaluation fixed the binding, once and for all. A plain Python closure, by contrast, would look up the current value of key_value on every call.