I'm playing around with Spark Structured Streaming but have stumbled upon an issue for which I can't find the root cause or a resolution.
I have defined a class Reader with a function read_events that receives a transactions_df DataFrame as input, reads a Delta table, and joins the output of that read with the incoming DataFrame. The function returns a DataFrame.
Code of the Reader class including the transformation function:
from pyspark.sql import DataFrame
from pyspark.sql.functions import col
class Reader:
    def __init__(self, spark):
        self.spark = spark
        self.events_table = "events_table"

    def read_events(self, transactions_df):
        events_df = self.spark.read.table(self.events_table)
        result_df = (
            transactions_df.alias("T")
            .join(
                events_df.alias("E"),
                col("E.col1") == col("T.col1"),
            )
            .select(
                col("E.col1"),
                col("E.col2"),
            )
        )
        return result_df
Executing this function standalone works as expected; the statement below gives the expected output.
input_df = spark.read.table("transactions_table")
Reader(spark).read_events(input_df).display()
But when I try to apply this function to the output of a micro-batch of the streaming query, inside a foreachBatch, I receive the following error message:
[STREAMING_CONNECT_SERIALIZATION_ERROR] Cannot serialize the function `foreachBatch`. If you accessed the Spark session, or a DataFrame defined outside of the function, or any object that contains a Spark session, please be aware that they are not allowed in Spark Connect. For `foreachBatch`, please access the Spark session using `df.sparkSession`, where `df` is the first parameter in your `foreachBatch` function. For `StreamingQueryListener`, please access the Spark session using `self.spark`. For details please check out the PySpark doc for `foreachBatch` and `StreamingQueryListener`.
Code of the Streaming class:
from pyspark.sql import DataFrame
class Streaming:
    def __init__(self, spark):
        self.spark = spark
        self.transactions_table = "transactions_table"
        self.transactions_table_stream_checkpoint = "checkpoint_path"
        self.reader = Reader(spark)

    def process_batch_of_messages(self, df, batch_id):
        result_df = self.reader.read_events(df)
        print(f"For batch {batch_id} we have {result_df.count()} records.")

    def launch(self):
        (
            self.spark.readStream.format("delta")
            .option("skipChangeCommits", "true")
            .table(self.transactions_table)
            .writeStream.option("checkpointLocation", self.transactions_table_stream_checkpoint)
            .foreachBatch(
                lambda transactions, batch_id: self.process_batch_of_messages(
                    df=transactions, batch_id=batch_id
                )
            )
            .start()
        )


def entrypoint():
    stream = Streaming(spark)
    stream.launch()


if __name__ == "__main__":
    entrypoint()
Any help or suggestions in the right direction on the root cause and possible resolutions would be greatly appreciated!
The reason you are seeing this error is that you put the local Spark session into the foreachBatch function. This happens when you initialize the Reader:
self.reader = Reader(spark)
And that session is then used in the read_events method of Reader:
events_df = self.spark.read.table(self.events_table)
As the error says,
If you accessed the Spark session, or a DataFrame defined outside of the function, or any object that contains a Spark session, please be aware that they are not allowed in Spark Connect. For foreachBatch, please access the Spark session using df.sparkSession, where df is the first parameter in your foreachBatch function
You need to avoid putting the local Spark session into the foreachBatch function; in Reader.read_events, access the Spark session through the incoming DataFrame instead:
events_df = transactions_df.sparkSession.read.table(self.events_table)
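Put together, the Reader could then look roughly like this (a minimal sketch that keeps the table name and join columns from your question; the constructor no longer needs a Spark session at all, so capturing self.reader inside foreachBatch no longer drags a local session into the serialized function):

from pyspark.sql import DataFrame
from pyspark.sql.functions import col

class Reader:
    def __init__(self):
        self.events_table = "events_table"

    def read_events(self, transactions_df: DataFrame) -> DataFrame:
        # Use the session attached to the micro-batch DataFrame instead of
        # a session captured from the driver scope.
        events_df = transactions_df.sparkSession.read.table(self.events_table)
        return (
            transactions_df.alias("T")
            .join(events_df.alias("E"), col("E.col1") == col("T.col1"))
            .select(col("E.col1"), col("E.col2"))
        )

Because the session is taken from the batch DataFrame itself, the same read_events should also keep working when you call it on a plain DataFrame outside of the stream.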
This is, sadly, a breaking change introduced with Spark Connect: code that worked before may need some changes before it works on Spark Connect.