I have a Spark RDD of over 6 billion rows of data that I want to use to train a deep learning model with `train_on_batch`. I can't fit all the rows into memory, so I would like to get 10K or so at a time and batch them into chunks of 64 or 128 (depending on model size). I am currently using `sample()`, but I don't think that guarantees I will get every row. Is there a better way to partition the data to make it more manageable, so that I can write a generator function for getting batches? My code is below:
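```python
data_df = spark.read.parquet(PARQUET_FILE)
print(f'RDD Count: {data_df.count()}')  # 6B+
# limit() keeps the sample as a DataFrame so toPandas() works
# (take() would return a plain Python list of Row objects)
data_sample = data_df.sample(True, 0.0000015).limit(6400)
sample_df = data_sample.toPandas()

def get_batch():
    for row in sample_df.itertuples():
        # TODO: put together a batch size of BATCH_SIZE
        yield row

# Create the generator once; calling get_batch() inside the loop
# would restart it and return the first row every time
batches = get_batch()
for i in range(10):
    print(next(batches))
```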
I don't believe Spark lets you offset or paginate your data.
But you can add an index and then paginate over it. First:
```python
from pyspark.sql.functions import lit

data_df = spark.read.parquet(PARQUET_FILE)
count = data_df.count()
chunk_size = 10000

# Only used to build the schema: the original columns plus 'pres_id'.
# Cast to long, because row ids can exceed the 32-bit range of lit(1)'s IntegerType
df_new_schema = data_df.withColumn('pres_id', lit(1).cast('long'))

# Adding the 1-based ids to the rdd (Python 3 lambdas cannot unpack tuples)
rdd_with_index = data_df.rdd.zipWithIndex().map(lambda pair: list(pair[0]) + [pair[1] + 1])

# Creating a dataframe with index; persist it so the parquet read and
# zipWithIndex are not recomputed for every chunk
df_with_index = spark.createDataFrame(rdd_with_index, schema=df_new_schema.schema).persist()

# Iterating over the chunks: each one holds ids in (offset, offset + chunk_size]
for offset in range(0, count, chunk_size):
    where_query = 'pres_id > {0} and pres_id <= {1}'.format(offset, offset + chunk_size)
    chunk_df = df_with_index.where(where_query).toPandas()
    train_on_batch(chunk_df)  # <== Your function here
```
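The paging arithmetic can be checked on its own, without Spark. A minimal sketch, assuming the ids are 1-based (as produced by `rowId+1`): every window `(lower, upper]` starts where the previous one ended, so each id from 1 to `count` lands in exactly one chunk. `page_windows` and `where_query` are hypothetical helper names, not Spark APIs.

```python
def page_windows(count, chunk_size):
    """Yield (lower, upper) bounds so that `lower < pres_id <= upper`
    covers the 1-based ids 1..count exactly once."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    for lower in range(0, count, chunk_size):
        # Clamp the last window so it does not extend past the final id.
        yield lower, min(lower + chunk_size, count)


def where_query(lower, upper):
    """Build the SQL filter string passed to DataFrame.where()."""
    return f"pres_id > {lower} and pres_id <= {upper}"


if __name__ == "__main__":
    for lower, upper in page_windows(25, 10):
        print(where_query(lower, upper))
```

For 25 rows and a chunk size of 10 this prints three filters, ending with `pres_id > 20 and pres_id <= 25`. With 6 billion rows and `chunk_size = 10000` the loop runs 600,000 times, and each `toPandas()` call is a separate Spark job, which is a large part of why this approach is slow at that scale.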
This is not optimal: it makes poor use of Spark, because every chunk is collected to the driver as a pandas DataFrame, but it will solve your problem.
Don't forget to drop the `pres_id` column if it would affect your function.
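To go from the 10,000-row chunks to the 64- or 128-row batches the question asks for, a generator can drop the id column and carry leftover rows across chunk boundaries (10,000 is not a multiple of 64 or 128, so without carrying, every chunk would end in a short batch). This is a minimal sketch, assuming each chunk is a pandas DataFrame such as `chunk_df`; `iter_batches` is a hypothetical name, not a Spark or Keras API.

```python
import pandas as pd


def iter_batches(chunks, batch_size, id_col="pres_id"):
    """Yield DataFrames of exactly batch_size rows; only the final one may be shorter.

    chunks: an iterable of pandas DataFrames (e.g. the chunk_df values from the loop).
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    leftover = None
    for chunk in chunks:
        # Drop the helper index column so it is never fed to the model.
        chunk = chunk.drop(columns=[id_col], errors="ignore")
        if leftover is not None:
            # Prepend the rows left over from the previous chunk.
            chunk = pd.concat([leftover, chunk], ignore_index=True)
        n_full = len(chunk) // batch_size * batch_size
        for start in range(0, n_full, batch_size):
            yield chunk.iloc[start:start + batch_size]
        # Keep the tail (fewer than batch_size rows) for the next chunk.
        leftover = chunk.iloc[n_full:] if n_full < len(chunk) else None
    if leftover is not None:
        yield leftover
```

To use it, the paging loop would `yield chunk_df` instead of training on it, and each batch could then be passed to something like `model.train_on_batch(batch[feature_cols].to_numpy(), batch[label_col].to_numpy())`, where `feature_cols` and `label_col` are placeholders for your own column names.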