python · apache-spark · pyspark · rdd

Iterative filter in Spark doesn't seem to work


I'm trying to remove the elements of an RDD one by one, but it doesn't work: the elements reappear.

Here is part of my code:

rdd = spark.sparkContext.parallelize([0,1,2,3,4])
for i in range(5):
    rdd=rdd.filter(lambda x:x!=i)
print(rdd.collect())
# output: [0, 1, 2, 3]

So it seems that only the last filter is "remembered". I was expecting the RDD to be empty after this loop.

However, I do not understand why: each time I save the new RDD returned by filter back into rdd, so shouldn't all the transformations be kept? If not, how should I do this?

Thank you for pointing out where I am wrong!


Solution

  • The result is actually correct; it is not a bug in Spark. Note that the lambda is defined as x != i, and the value of i is not substituted into the lambda when it is created: Python closures look up i only when the function is actually evaluated (late binding). So after each iteration of the for loop, the RDD lineage looks like

    rdd
    rdd.filter(lambda x: x != i)
    rdd.filter(lambda x: x != i).filter(lambda x: x != i)
    rdd.filter(lambda x: x != i).filter(lambda x: x != i).filter(lambda x: x != i)
    

    etc.

    Since all the filters reference the same variable i, and Spark evaluates them lazily only when collect() is called, every filter sees the final value of i (here 4), so only that single element ends up being filtered away.
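
    To see that this is plain Python behaviour rather than anything Spark-specific, here is a minimal sketch without Spark: every lambda closes over the single loop variable i and therefore tests against its final value.

    filters = []
    for i in range(5):
        filters.append(lambda x: x != i)

    # By the time the lambdas are called, i is 4, so each one behaves
    # like lambda x: x != 4.
    print([f(4) for f in filters])  # [False, False, False, False, False]
    print([f(0) for f in filters])  # [True, True, True, True, True]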

    To avoid this, you can use functools.partial to bind the current value of i into the function:

    from functools import partial
     
    rdd = spark.sparkContext.parallelize([0,1,2,3,4])
    for i in range(5):
        # partial freezes the current value of i into the function at
        # definition time, so each filter tests against its own value
        rdd = rdd.filter(partial(lambda x, i: x != i, i))
    
    print(rdd.collect())
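
    As an alternative sketch (equivalent in effect to partial, not the only way to do it), you can freeze i as a default argument, since default values are evaluated at the moment the lambda is defined:

    rdd = spark.sparkContext.parallelize([0, 1, 2, 3, 4])
    for i in range(5):
        # i=i stores the current value of i as the lambda's own default,
        # so each filter keeps the value from its loop iteration
        rdd = rdd.filter(lambda x, i=i: x != i)

    print(rdd.collect())  # []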
    

    Or you can use reduce; here every call of the outer lambda creates its own local i, so each inner filter lambda captures a distinct value:

    from functools import reduce
    
    rdd = spark.sparkContext.parallelize([0,1,2])
    rdd = reduce(lambda r, i: r.filter(lambda x: x != i), range(3), rdd)
    print(rdd.collect())
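
    With either approach each filter holds on to its own value of i, so every element is removed and collect() returns an empty list.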