python, apache-spark, pyspark, rdd

How do you get batches of rows from Spark using pyspark?


I have a Spark RDD of over 6 billion rows of data that I want to use to train a deep learning model, using train_on_batch. I can't fit all the rows into memory so I would like to get 10K or so at a time to batch into chunks of 64 or 128 (depending on model size). I am currently using rdd.sample() but I don't think that guarantees I will get all rows. Is there a better method to partition the data to make it more manageable so that I can write a generator function for getting batches? My code is below:

data_df = spark.read.parquet(PARQUET_FILE)
print(f'RDD Count: {data_df.count()}') # 6B+
data_sample = data_df.sample(True, 0.0000015).limit(6400)  # limit() keeps a DataFrame so toPandas() works
sample_df = data_sample.toPandas()

def get_batch():
  for row in sample_df.itertuples():
    # TODO: put together a batch size of BATCH_SIZE
    yield row

batch_gen = get_batch()
for i in range(10):
    print(next(batch_gen))

Solution

  • I don't believe Spark lets you offset or paginate your data.

    But you can add an index column and then paginate over it. First:

        from pyspark.sql.functions import lit
        from pyspark.sql.types import LongType

        data_df = spark.read.parquet(PARQUET_FILE)
        count = data_df.count()
        chunk_size = 10000

        # Build the target schema by adding a (long) id column
        df_new_schema = data_df.withColumn('pres_id', lit(1).cast(LongType()))

        # Append a 1-based index to every row of the rdd
        rdd_with_index = data_df.rdd.zipWithIndex().map(lambda row_index: list(row_index[0]) + [row_index[1] + 1])

        # Creating a dataframe with the index
        df_with_index = spark.createDataFrame(rdd_with_index, schema=df_new_schema.schema)

        # Iterating over the chunks
        for offset in range(0, count, chunk_size):
            where_query = 'pres_id > {0} and pres_id <= {1}'.format(offset, offset + chunk_size)
            chunk_df = df_with_index.where(where_query).toPandas()
            train_on_batch(chunk_df)  # <== Your function here
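
    If you want the generator interface described in the question rather than one training call per 10K chunk, a rough sketch building on the code above could look like this (BATCH_SIZE is an assumed constant; df_with_index, count and chunk_size come from the snippet above):

        BATCH_SIZE = 128  # assumed: 64 or 128 depending on model size

        def get_batches():
            # Page through the indexed dataframe in chunks, then slice each
            # pandas chunk into fixed-size batches.
            for offset in range(0, count, chunk_size):
                where_query = 'pres_id > {0} and pres_id <= {1}'.format(offset, offset + chunk_size)
                chunk_df = df_with_index.where(where_query).toPandas()
                for start in range(0, len(chunk_df), BATCH_SIZE):
                    yield chunk_df.iloc[start:start + BATCH_SIZE]

        for batch in get_batches():
            train_on_batch(batch)  # <== Your function here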
    

    Paginating with where() and collecting each chunk into a pandas DataFrame is not optimal and makes poor use of Spark, but it will solve your problem.

    Don't forget to drop the pres_id column if it affects your function.
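
    If the repeated where() scans or the per-chunk toPandas() calls become the bottleneck, another option worth sketching (not part of the approach above) is DataFrame.toLocalIterator(), which streams rows to the driver one partition at a time, so you can assemble batches without adding an index at all. BATCH_SIZE is again an assumed placeholder:

        import pandas as pd

        def stream_batches(df, batch_size=128):
            # toLocalIterator() brings one partition at a time to the driver,
            # so only a partition's worth of rows is held in memory here.
            batch = []
            for row in df.toLocalIterator():
                batch.append(row.asDict())
                if len(batch) == batch_size:
                    yield pd.DataFrame(batch)
                    batch = []
            if batch:
                yield pd.DataFrame(batch)  # final, possibly smaller batch

        for batch_df in stream_batches(data_df, batch_size=128):
            train_on_batch(batch_df)  # <== Your function here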