I've got a few dozen Spark tables in Databricks with sizes between ~1 and ~20 GB and want to execute a function on each of these tables. Since there is no interdependency between the results of the individual queries, this should be easy to parallelize.
However, I have no idea how to tell PySpark to run the following code in parallel. It just processes one table after another.
This is a simple demo to show the structure of my code:
Cell 1 (create some demo tables):
import numpy as np

tables = []
columns = list("abc")

for i in range(10):
    nrows = int(1E6)
    ncols = len(columns)
    data = np.random.rand(ncols * nrows).reshape((nrows, ncols))
    schema = ", ".join([f"{col}: float" for col in columns])
    table = spark.createDataFrame(data=data, schema=schema)
    tables.append(table)
Cell 2 (perform an operation on each of them):
quantiles = {}

for i, table in enumerate(tables):
    quantiles[i] = table.approxQuantile(columns, [0.01, 0.99], relativeError=0.001)
Note: The demo is a bit simplified. In reality each table has different columns, so I can't simply concatenate them.
from concurrent.futures import ThreadPoolExecutor

def run_io_tasks_in_parallel(tasks):
    with ThreadPoolExecutor() as executor:
        running_tasks = [executor.submit(task) for task in tasks]
        for running_task in running_tasks:
            print(running_task.result())

# Replace with actual logic.
def process_table(table_name):
    return table_name, spark.sql(f"select count(*) from {table_name}").collect()[0][0]

table_names = ['A', 'B', 'C']

# Bind each name as a default argument so every lambda captures its own
# table name instead of the value left over from the last loop iteration.
tasks = [lambda x=name: process_table(x) for name in table_names]

run_io_tasks_in_parallel(tasks)
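If you would rather avoid the lambdas, a roughly equivalent sketch (reusing the process_table helper above; the function name run_io_tasks_in_parallel_map is just for illustration) hands the names straight to executor.map, which yields the results in input order:

from concurrent.futures import ThreadPoolExecutor

def run_io_tasks_in_parallel_map(table_names):
    with ThreadPoolExecutor() as executor:
        # executor.map submits one call per name and yields the results
        # in the same order as the inputs, blocking until each is ready.
        for name, count in executor.map(process_table, table_names):
            print(name, count)

run_io_tasks_in_parallel_map(['A', 'B', 'C'])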
EDIT: Modified the example with provided sample code. Hope this helps.
import numpy as np

tables = []

for i in range(10):
    columns = list("abc")
    nrows = int(1E6)
    ncols = len(columns)
    data = np.random.rand(ncols * nrows).reshape((nrows, ncols)).tolist()
    schema = ", ".join([f"{col}: float" for col in columns])
    table = spark.createDataFrame(data=data, schema=schema)
    tables.append(table)
from concurrent.futures import ThreadPoolExecutor

quantiles = {}

def run_io_tasks_in_parallel(tasks):
    with ThreadPoolExecutor() as executor:
        running_tasks = [executor.submit(task) for task in tasks]
        for running_task in running_tasks:
            index, result = running_task.result()
            quantiles[index] = result

def process_table(index, table):
    # Use the table's own columns so tables with different schemas work too.
    return index, table.approxQuantile(table.columns, [0.01, 0.99], relativeError=0.001)

# Bind the loop variables as default arguments so each lambda keeps its own
# index and table instead of the values from the last loop iteration.
tasks = [
    lambda x=index, y=table: process_table(x, y)
    for index, table in enumerate(tables)
]

run_io_tasks_in_parallel(tasks)
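Note that each thread still just submits ordinary Spark jobs, so the cluster's scheduler decides how they share executors. Below is a minimal sketch of capping the number of concurrent jobs and tagging them with a scheduler pool, assuming the cluster is configured with spark.scheduler.mode=FAIR (the pool name "quantiles" and the helper process_table_pooled are only examples):

from concurrent.futures import ThreadPoolExecutor

def process_table_pooled(index, table):
    # setLocalProperty is per thread, so each worker thread assigns its
    # own Spark jobs to the (example) "quantiles" scheduler pool.
    spark.sparkContext.setLocalProperty("spark.scheduler.pool", "quantiles")
    return index, table.approxQuantile(table.columns, [0.01, 0.99], relativeError=0.001)

# max_workers caps how many Spark jobs are in flight at the same time.
with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(process_table_pooled, i, t) for i, t in enumerate(tables)]
    quantiles = dict(f.result() for f in futures)

Without FAIR mode this still works; with the default FIFO scheduling the jobs queue in submission order and later ones only run concurrently when the earlier ones leave free task slots.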