I am using Databricks and PySpark, technologies that I am quite new to.
I have built a function that trains and predicts k-means models (from the sklearn library - I am aware that pyspark.ml exists, but I was instructed to use sklearn). My function performs some operations with Spark DataFrames and then proceeds with training and prediction. Something like this:
def train_and_predict(cat):
    # import (via spark.sql) the necessary tables for that specific category
    # some operations on the spark df, some joins, some filters...
    # training and saving to table
    # predicting and saving to table
My question is this: I want to run this function multiple times for different categories, and I would like to process many of them in parallel. Currently, I am using the joblib library, but I suspect that joblib is not fully leveraging the capabilities of the Spark cluster. How can I run this function in parallel to take full advantage of the Spark cluster?
To train and predict for different categories in parallel, you would need to parallelize the list of categories and call the function on it, like below.
categories = ['category1', 'category2', 'category3']
category_rdd = sc.parallelize(categories)

def train_and_predict(category):
    # this spark.sql call runs on the executors, which is the problem
    spark.sql(f"select * from data where category = '{category}'")
    return category

category_rdd.map(train_and_predict).collect()
But since you fetch the data per category via spark.sql inside that function, this will fail with an error saying that the SparkContext can only be used on the driver. Functions passed to RDD transformations run on the executors, so you cannot reference the SparkSession or SparkContext inside them; as written, this approach cannot utilize the full cluster.
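If you do want to keep sklearn and still push the training onto the executors, one possible workaround is to run all spark.sql calls on the driver first, convert each category's slice to pandas, and only then parallelize the sklearn fitting. This is only a rough sketch, assuming each slice fits in driver memory; the feature columns feature1 and feature2 are hypothetical.

from sklearn.cluster import KMeans

categories = ['category1', 'category2', 'category3']

# spark.sql and toPandas() run only on the driver
data_per_category = [
    (cat, spark.sql(f"select * from data where category = '{cat}'").toPandas())
    for cat in categories
]

def fit_kmeans(item):
    cat, pdf = item
    # only pandas/sklearn here - no spark or dbutils references on the executor
    model = KMeans(n_clusters=3)
    labels = model.fit_predict(pdf[["feature1", "feature2"]])
    return cat, labels.tolist()

results = sc.parallelize(data_per_category, len(categories)).map(fit_kmeans).collect()

Shipping pandas DataFrames through an RDD only makes sense when each slice is reasonably small, since every slice is serialized from the driver to an executor.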
My suggestion, though, is to put your functionality in a separate notebook and call that notebook in parallel, passing a different category each time.
First, create a new notebook and add your code like below.
# the category is passed in as a notebook widget (default "00")
dbutils.widgets.text("category", "00")
cat = dbutils.widgets.get("category")

# import (via spark.sql) the necessary tables for that specific category
# some operations on the spark df, some joins, some filters...
# training and saving to table
# predicting and saving to table

# return a value to the calling notebook
dbutils.notebook.exit(cat)
And call it like below.
from concurrent.futures import ThreadPoolExecutor

categories = ["category1", "category2", "category3"]

def train_data(cat):
    # each call runs the training notebook as its own job on the cluster
    return dbutils.notebook.run("/Users/[email protected]/train", 300, {"category": cat})

with ThreadPoolExecutor() as executor:
    for res in executor.map(train_data, categories):
        print(res)
Make sure you have enough resources for training, because the sklearn work in each notebook run still executes only on the driver node; to utilize the full cluster you would need to train within Spark itself, using pyspark.ml.
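For completeness, here is a minimal sketch of what per-category training could look like with pyspark.ml, where the k-means fitting itself is distributed across the cluster. The feature columns feature1 and feature2 and the output table names are hypothetical.

from pyspark.ml.feature import VectorAssembler
from pyspark.ml.clustering import KMeans

def train_and_predict_spark(cat):
    df = spark.sql(f"select * from data where category = '{cat}'")
    # assemble the feature columns into the single vector column pyspark.ml expects
    assembler = VectorAssembler(inputCols=["feature1", "feature2"], outputCol="features")
    assembled = assembler.transform(df)
    model = KMeans(k=3, featuresCol="features", predictionCol="prediction").fit(assembled)
    model.transform(assembled).write.mode("overwrite").saveAsTable(f"predictions_{cat}")

for cat in ["category1", "category2", "category3"]:
    train_and_predict_spark(cat)

Since each fit is already distributed over the executors, a plain loop over the categories is usually enough here.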