pyspark rdd batch-processing

How to collect the elements of an RDD in batches


I have a PySpark RDD with ~2 million elements. I cannot collect them all at once, because doing so causes an OutOfMemoryError.

How can I collect them in batches?

This is a potential solution, but I suspect there is a better one: collect a batch using take (https://spark.apache.org/docs/3.1.2/api/python/reference/api/pyspark.RDD.take.html#pyspark.RDD.take), remove the collected elements from the RDD using filter (https://spark.apache.org/docs/3.1.2/api/python/reference/api/pyspark.RDD.filter.html#pyspark.RDD.filter), and repeat until take returns nothing. I also suspect filter is not the best tool for the removal step.
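
For illustration, here is a rough sketch of that take-then-filter loop (my own sketch, untested at this scale). It assumes the elements are hashable and distinct, that `rdd` is the RDD in question, and that `handle_batch` is a placeholder for whatever should happen with each batch:

    batch_size = 100000
    remaining = rdd
    while True:
        batch = remaining.take(batch_size)   # pull the next batch to the driver
        if not batch:
            break
        handle_batch(batch)                  # placeholder for per-batch processing
        seen = set(batch)                    # assumes hashable, distinct elements
        # Bind `seen` as a default argument so each chained filter keeps its own
        # batch instead of the latest one (Python closures are late-binding).
        remaining = remaining.filter(lambda x, seen=seen: x not in seen)

Each iteration re-evaluates the whole filter chain, so this gets slower as the number of batches grows, which is part of why it feels like there should be a better approach.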


Solution

  • I'm not sure it's a good solution, but you can zip your RDD with an index and then filter on that index to collect the items in batches:

    # `spark` is an existing SparkSession
    big_rdd = spark.sparkContext.parallelize([str(i) for i in range(0, 100)])
    # pair every element with its position: (element, index)
    big_rdd_with_index = big_rdd.zipWithIndex()
    batch_size = 10
    batches = []
    for i in range(0, 100, batch_size):
      # collect only the elements whose index falls in the current window
      batches.append(
          big_rdd_with_index
          .filter(lambda element: i <= element[1] < i + batch_size)
          .map(lambda element: element[0])
          .collect()
      )
    for l in batches:
      print(l)
    

    Output:

    ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
    ['10', '11', '12', '13', '14', '15', '16', '17', '18', '19']
    ['20', '21', '22', '23', '24', '25', '26', '27', '28', '29']
    ['30', '31', '32', '33', '34', '35', '36', '37', '38', '39']
    ['40', '41', '42', '43', '44', '45', '46', '47', '48', '49']
    ['50', '51', '52', '53', '54', '55', '56', '57', '58', '59']
    ['60', '61', '62', '63', '64', '65', '66', '67', '68', '69']
    ['70', '71', '72', '73', '74', '75', '76', '77', '78', '79']
    ['80', '81', '82', '83', '84', '85', '86', '87', '88', '89']
    ['90', '91', '92', '93', '94', '95', '96', '97', '98', '99']
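
    A small variation on the same idea (just a sketch, not part of the answer above): cache the indexed RDD so each batch does not recompute the whole lineage, and derive the loop bound from count() instead of hardcoding the number of elements:

    indexed = big_rdd.zipWithIndex().cache()   # keep the indexed RDD around between batches
    total = indexed.count()
    batch_size = 10
    for start in range(0, total, batch_size):
      batch = (indexed
               .filter(lambda kv, lo=start: lo <= kv[1] < lo + batch_size)
               .map(lambda kv: kv[0])
               .collect())
      print(batch)
    indexed.unpersist()                        # release the cached copy when done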