I have a PySpark RDD with ~2 million elements. I cannot collect them all at once, because that causes an OutOfMemoryError exception.
How can I collect them in batches?
This is a potential solution, but I suspect there is a better one: collect a batch using take (https://spark.apache.org/docs/3.1.2/api/python/reference/api/pyspark.RDD.take.html#pyspark.RDD.take), then remove the elements of that batch from the RDD using filter (https://spark.apache.org/docs/3.1.2/api/python/reference/api/pyspark.RDD.filter.html#pyspark.RDD.filter, though I suspect there is a better way to do that too), and repeat until nothing is left to collect.
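Roughly what I have in mind (only a sketch: rdd stands in for my real RDD and process_batch is a hypothetical handler for one collected batch):
batch_size = 10000
remaining = rdd  # placeholder for the real RDD
while True:
    batch = remaining.take(batch_size)
    if not batch:
        break
    process_batch(batch)  # hypothetical: do whatever is needed with this batch
    # bind the batch through a default argument so each filter keeps its own copy;
    # assumes the elements are hashable and distinct
    remaining = remaining.filter(lambda x, taken=frozenset(batch): x not in taken)
The chain of filters gets re-evaluated on every take, which is part of why I suspect there is something better.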
I'm not sure it's a good solution, but you can zip your RDD with an index and then filter on that index to collect the elements in batches:
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

big_rdd = spark.sparkContext.parallelize([str(i) for i in range(0, 100)])
big_rdd_with_index = big_rdd.zipWithIndex()  # pair each element with a stable index
batch_size = 10
batches = []
for i in range(0, 100, batch_size):
    # keep the elements whose index falls in the current window, drop the index, then collect
    batches.append(big_rdd_with_index.filter(lambda element: i <= element[1] < i + batch_size).map(lambda element: element[0]).collect())
for l in batches:
    print(l)
Output:
['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
['10', '11', '12', '13', '14', '15', '16', '17', '18', '19']
['20', '21', '22', '23', '24', '25', '26', '27', '28', '29']
['30', '31', '32', '33', '34', '35', '36', '37', '38', '39']
['40', '41', '42', '43', '44', '45', '46', '47', '48', '49']
['50', '51', '52', '53', '54', '55', '56', '57', '58', '59']
['60', '61', '62', '63', '64', '65', '66', '67', '68', '69']
['70', '71', '72', '73', '74', '75', '76', '77', '78', '79']
['80', '81', '82', '83', '84', '85', '86', '87', '88', '89']
['90', '91', '92', '93', '94', '95', '96', '97', '98', '99']
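Two caveats about the demo: each collect() runs a separate job over the full lineage, so it's worth persisting the indexed RDD once before the loop (e.g. big_rdd_with_index.cache()) so zipWithIndex isn't recomputed for every batch; and in your real use case you would process each batch inside the loop rather than appending everything to batches, otherwise all ~2 million elements end up in driver memory again.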