apache-spark, pyspark, databricks

Parallelize for-loop in pyspark; one table per iteration


I've got a few dozen Spark tables in Databricks, each between roughly 1 and 20 GB, and I want to run a function on each of them. Since the results of the individual queries don't depend on each other, this should be easy to parallelize.

However, I have no idea how to tell PySpark to run the following code in parallel; it just processes one table after another.

This is a simple demo to show the structure of my code:

Cell 1 (create some demo tables):

import numpy as np

tables = []
columns = list("abc")
for i in range(10):
    nrows = int(1E6)
    ncols = len(columns)
    data = np.random.rand(ncols * nrows).reshape((nrows, ncols))
    schema = ", ".join([f"{col}: float" for col in columns])
    table = spark.createDataFrame(data=data, schema=schema)
    tables.append(table)

Cell 2 (perform an operation on each of them):

quantiles = {}
for i, table in enumerate(tables):
    quantiles[i] = table.approxQuantile(columns, [0.01, 0.99], relativeError=0.001)

Note: The demo is a bit simplified; in reality each table has different columns, so I can't just concatenate them into a single DataFrame.


Solution

    from concurrent.futures import ThreadPoolExecutor
    
    def run_io_tasks_in_parallel(tasks):
        with ThreadPoolExecutor() as executor:
            running_tasks = [executor.submit(task) for task in tasks]
            for running_task in running_tasks:
                print(running_task.result())
    
    # Replace with actual logic.
    def process_table(table_name):
        return table_name, spark.sql(f"select count(*) from {table_name}").collect()[0][0]
    
    table_names = ['A', 'B', 'C']
    
    # Bind each name as a default argument so every lambda captures its own
    # table name rather than the loop variable (avoids late binding).
    tasks = [lambda x=name: process_table(x) for name in table_names]
    
    run_io_tasks_in_parallel(tasks)
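
Since process_table takes a single argument here, the per-name lambdas are not strictly required; executor.map can apply the function to each name directly. A minimal sketch of that variant, reusing the same hypothetical table names:

    from concurrent.futures import ThreadPoolExecutor

    # executor.map applies process_table to each name and yields the results
    # in submission order, so no per-task lambdas are needed.
    with ThreadPoolExecutor() as executor:
        for name, count in executor.map(process_table, table_names):
            print(name, count)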
    

EDIT: Modified the example to use the sample code from the question. Hope this helps.

    import numpy as np
    
    tables = []
    columns = list("abc")
    for i in range(10):
        nrows = int(1E6)
        ncols = len(columns)
        data = np.random.rand(ncols * nrows).reshape((nrows, ncols)).tolist()
        schema = ", ".join([f"{col}: float" for col in columns])
        table = spark.createDataFrame(data=data, schema=schema)
        tables.append(table)
    
    from concurrent.futures import ThreadPoolExecutor
    quantiles = {}
    
    def run_io_tasks_in_parallel(tasks):
        with ThreadPoolExecutor() as executor:
            running_tasks = [executor.submit(task) for task in tasks]
            for running_task in running_tasks:
                index, result = running_task.result()
                quantiles[index] = result
    
    def process_table(index, table):
        return index, table.approxQuantile(columns, [0.01, 0.99], relativeError=0.001)
    
    # Bind index and table as default arguments so each lambda captures its
    # own pair (the same late-binding fix as above).
    tasks = [
        lambda x=index, y=table: process_table(x, y)
        for index, table in enumerate(tables)
    ]
    
    run_io_tasks_in_parallel(tasks)
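
The threads only submit Spark actions from the driver; the actual work still runs inside Spark's scheduler, which can execute the submitted jobs concurrently (subject to the cluster's scheduling mode and available resources). If you want to cap how many jobs are in flight at once, one option is to limit max_workers and collect results as they finish. A minimal sketch, reusing process_table from above:

    from concurrent.futures import ThreadPoolExecutor, as_completed

    quantiles = {}

    # Submit at most 4 Spark jobs at a time; store each result as soon as
    # its future completes.
    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(process_table, i, t) for i, t in enumerate(tables)]
        for future in as_completed(futures):
            index, result = future.result()
            quantiles[index] = result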