scala, apache-spark, apache-spark-sql

Process Spark DataFrame partitions in batches, N partitions at a time


I need to process Spark DataFrame partitions in batches, N partitions at a time. For example, if I have 1000 partitions in a Hive table, I need to process 100 partitions at a time.

I tried the following approach:

  1. Get the partition list from the Hive table and find the total count

  2. Get the loop count using total_count / 100

  3. Then

     for x in range(loop_count):
         files_list=partition_path_list[start_index:end_index]            
         df = spark.read.option("basePath", target_table_location).parquet(*files_list)
    

But this is not working as expected. Can anyone suggest a better method? A solution in Spark Scala is preferred.


Solution

  • The for loop you have only increments x on each iteration; start_index and end_index are never updated, so every iteration reads the same slice.

    Not sure why you mention Scala since your code is in Python. Here's an example with a loop count of 1000.

    partitions_per_iteration = 100
    loop_count = 1000
    # step through the partition paths in slices of 100
    for start_index in range(0, loop_count, partitions_per_iteration):
        files_list = partition_path_list[start_index:start_index + partitions_per_iteration]
        df = spark.read.option("basePath", target_table_location).parquet(*files_list)
        
    

    In Scala, you can do a similar loop:

    val partitionsPerIteration = 100
    val total = 1000
    for {
        startIndex <- 0 until total by partitionsPerIteration
    } {
        val filesList = partitionsPathList.slice(startIndex, startIndex + partitionsPerIteration)
        val df = spark.read.option("basePath", targetTableLocation).parquet(filesList: _*)  // same read as in the Python version
    }
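
    If you prefer to avoid index arithmetic altogether, Scala's grouped produces the same batches. A sketch, assuming partitionsPathList is an in-memory Seq[String] of partition paths and targetTableLocation is the table's root location:

    // grouped(n) yields an Iterator of successive slices of at most n elements
    partitionsPathList.grouped(partitionsPerIteration).foreach { filesList =>
        val df = spark.read.option("basePath", targetTableLocation).parquet(filesList: _*)
        // process this batch of up to 100 partitions before moving on
    }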
    

    I think total or totalPartitions is a clearer variable name than "loop count".
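
    For completeness, here is a rough end-to-end sketch that also covers steps 1 and 2 from the question. The database/table name, table location, and output path are placeholders, and it assumes the usual key=value partition directory layout under the table location plus an existing SparkSession (spark) built with Hive support:

    val targetTableLocation = "/warehouse/my_db.db/my_table"   // placeholder basePath
    val partitionsPerIteration = 100

    // Step 1: SHOW PARTITIONS returns rows like "dt=2021-01-01"; turn them into full paths
    val partitionsPathList = spark.sql("SHOW PARTITIONS my_db.my_table")
        .collect()
        .map(row => s"$targetTableLocation/${row.getString(0)}")

    // Step 2: the total count drives the loop, no separate loop_count needed
    val totalPartitions = partitionsPathList.length

    for (startIndex <- 0 until totalPartitions by partitionsPerIteration) {
        val filesList = partitionsPathList.slice(startIndex, startIndex + partitionsPerIteration)
        val df = spark.read.option("basePath", targetTableLocation).parquet(filesList: _*)
        df.write.mode("append").parquet("/tmp/output_path")   // placeholder: process/write each batch here
    }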