I'm trying to remove the elements of an RDD one by one, but that doesn't work: the elements reappear. Here is part of my code:
rdd = spark.sparkContext.parallelize([0, 1, 2, 3, 4])
for i in range(5):
    rdd = rdd.filter(lambda x: x != i)
print(rdd.collect())
# output: [0, 1, 2, 3]
So it seems that only the last filter is "remembered". I was expecting the RDD to be empty after this loop.
I do not understand why: every time, I save the new RDD returned by filter back into rdd, so shouldn't it keep all the transformations? If not, how should I do it?
Thank you for pointing out where I am wrong!
The result is actually correct; it is not a bug in Spark. Note that the lambda is defined as x != i, and i is not substituted into the lambda when it is created: the lambda only keeps a reference to the variable i, and since Spark transformations are lazy, none of the filters actually run until collect() is called. So after each iteration of the for loop, the RDD lineage looks like
rdd
rdd.filter(lambda x: x != i)
rdd.filter(lambda x: x != i).filter(lambda x: x != i)
rdd.filter(lambda x: x != i).filter(lambda x: x != i).filter(lambda x: x != i)
etc.
Since the filters all share the same lambda, and i is only looked up when collect() finally runs, every filter tests against the latest value of i (here i == 4), so only that one item is filtered away.
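This is Python's usual late-binding behaviour for closures, and you can reproduce it without Spark at all. A minimal sketch in plain Python (the names funcs and f are just for illustration):

# All three lambdas close over the same loop variable i;
# i is looked up only when a lambda is actually called.
funcs = []
for i in range(3):
    funcs.append(lambda x: x != i)

# By now i == 2, so every lambda effectively tests x != 2.
print([f(2) for f in funcs])  # [False, False, False]
print([f(0) for f in funcs])  # [True, True, True]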
To avoid this, you can use functools.partial to make sure the current value of i is bound into the function at the moment the filter is created:
from functools import partial

rdd = spark.sparkContext.parallelize([0, 1, 2, 3, 4])
for i in range(5):
    # partial captures the value of i right here, at creation time
    rdd = rdd.filter(partial(lambda x, i: x != i, i))
print(rdd.collect())
# output: []
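An equivalent idiom you will often see (not used in this answer, but standard Python) is to bind i through a default argument, which is likewise evaluated once, when the lambda is created:

rdd = spark.sparkContext.parallelize([0, 1, 2, 3, 4])
for i in range(5):
    # the default value i=i is evaluated now, not at call time
    rdd = rdd.filter(lambda x, i=i: x != i)
print(rdd.collect())
# output: []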
Or you can use reduce, which works because the inner lambda closes over the parameter i of the outer lambda, and each call of the outer lambda gets a fresh binding of i:
from functools import reduce

rdd = spark.sparkContext.parallelize([0, 1, 2])
# each call of the outer lambda has its own i, so the inner lambda
# captures the intended value rather than a shared loop variable
rdd = reduce(lambda r, i: r.filter(lambda x: x != i), range(3), rdd)
print(rdd.collect())
# output: []
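All of these fixes build the same chain of filter transformations as your original loop; the only difference is when the value of i is captured: at the time each filter is created, rather than when collect() is finally evaluated.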