Tags: azure, apache-spark, pyspark, databricks, azure-databricks

Executing a function in parallel for multiple arguments on Databricks


I am using Databricks and PySpark, technologies that I am quite new to.

I have built a function that trains and predicts k-means models (from the sklearn library - I am aware that pyspark.ml exists, but I was instructed to use sklearn). My function performs some operations with Spark DataFrames and then proceeds with training and prediction. Something like this:

def train_and_predict(cat):
    #import (via spark.sql) the necessary tables for that specific category
    #some operations on the spark df, some joins, some filters...
    #training and saving to table
    #predicting and saving to table

My question is this: I want to run this function multiple times for different categories, and I would like to process many of them in parallel. Currently, I am using the joblib library, but I suspect that joblib is not fully leveraging the capabilities of the Spark cluster. How can I run this function in parallel to take full advantage of the Spark cluster?
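
Roughly, the joblib version of the loop looks like the sketch below (the category names and n_jobs are placeholders, not my real values):

from joblib import Parallel, delayed

categories = ["category1", "category2", "category3"]

# runs the calls with driver-local threads so the Spark session stays usable inside the function
Parallel(n_jobs=4, backend="threading")(
    delayed(train_and_predict)(cat) for cat in categories
)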


Solution

  • To train and predict for different categories in parallel, the first idea is to parallelize the list of categories and call the function on the resulting RDD, as below.

    categories = ['category1', 'category2', 'category3']
    
    # distribute the list of categories across the cluster
    category_rdd = sc.parallelize(categories)
    
    def train_and_predict(category):
        # note the quotes around the category value in the SQL filter
        spark.sql(f"select * from data where category = '{category}'")
        return category
    
    category_rdd.map(train_and_predict).collect()
    
    

    But, since the data is fetched per category via spark.sql, this fails with an error saying that the SparkContext can only be used on the driver; the Spark session is not available inside functions that run on the executors.

    So, as long as you reference the Spark session or Spark context inside the function, utilizing the full cluster this way is not possible.

    My idea is to put your functionality in a separate notebook and call that notebook in parallel, passing a different category each time.

    First, create a new notebook and add your code like below.

    # widget that receives the category passed by the caller
    dbutils.widgets.text("category", "00")
    
    cat = dbutils.widgets.get("category")
    
    #import (via spark.sql) the necessary tables for that specific category
    #some operations on the spark df, some joins, some filters...
    #training and saving to table
    #predicting and saving to table
    
    # return a value to the calling notebook
    dbutils.notebook.exit(cat)
    

    And call it like below.

    from concurrent.futures import ThreadPoolExecutor
    
    categories = ["category1", "category2", "category3"]
    
    def train_data(cat):
        # run the training notebook for one category and return its exit value
        res = dbutils.notebook.run("/Users/[email protected]/train", timeout_seconds=300, arguments={"category": cat})
        print(res)
        return res
    
    # threads submit the notebook runs concurrently; map waits for all of them
    with ThreadPoolExecutor() as executor:
        results = list(executor.map(train_data, categories))
    

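    If there are many categories, you may also want to cap the number of concurrent notebook runs and catch failures per category so one bad run does not hide the others. A possible variation (the max_workers value of 4 is an arbitrary choice):

    def train_data_safe(cat):
        # guard each notebook run so a single failing category is reported, not raised
        try:
            return cat, dbutils.notebook.run("/Users/[email protected]/train", timeout_seconds=300, arguments={"category": cat})
        except Exception as e:
            return cat, f"failed: {e}"
    
    # limit concurrency so the driver is not overloaded
    with ThreadPoolExecutor(max_workers=4) as executor:
        for cat, res in executor.map(train_data_safe, categories):
            print(cat, res)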

    Make sure the driver has enough resources for training, because everything here still runs on the driver node; to utilize the full cluster you would need to train within the Spark context using pyspark.ml.
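
    For completeness, if switching away from sklearn ever becomes an option, a distributed k-means with pyspark.ml would look roughly like this (the table name, feature columns and k are placeholders, not taken from the question):

    from pyspark.ml.feature import VectorAssembler
    from pyspark.ml.clustering import KMeans
    
    # hypothetical table and feature columns
    df = spark.sql("select * from data where category = 'category1'")
    features = VectorAssembler(inputCols=["f1", "f2"], outputCol="features").transform(df)
    
    # k=3 is arbitrary; fitting and transforming run distributed on the executors
    model = KMeans(k=3, featuresCol="features", predictionCol="prediction").fit(features)
    predictions = model.transform(features)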