Tags: pyspark, apache-kafka, spark-structured-streaming

Can I "branch" stream into many and write them in parallel in pyspark?


I am receiving a Kafka stream in pyspark. Currently I am grouping it by one set of fields and writing updates to a database:

from pyspark.sql.functions import col, expr, min, struct

df = spark \
        .readStream \
        .format("kafka") \
        .option("kafka.bootstrap.servers", config["kafka"]["bootstrap.servers"]) \
        .option("subscribe", topic)

...

df = df \
        .groupBy("myfield1") \
        .agg(
            expr("count(*) as cnt"),
            min(struct(col("mycol.myfield").alias("mmm"), col("*"))).alias("minData")
        ) \
        .select("cnt", "minData.*") \
        .select(
            col("...").alias("..."),
            ...
            col("userId").alias("user_id")

query = df \
        .writeStream \
        .outputMode("update") \
        .foreachBatch(lambda df, epoch: write_data_frame(table_name, df, epoch)) \
        .start()

query.awaitTermination()

Can I take the same chain in the middle and create another grouping like

df2 = df \
        .groupBy("myfield2") \
        .agg(
            expr("count(*) as cnt"),
            min(struct(col("mycol.myfield").alias("mmm"), col("*"))).alias("minData")
        ) \
        .select("cnt", "minData.*") \
        .select(
            col("...").alias("..."),
            ...
            col("userId").alias("user_id")

and write its output to a different place, in parallel?

Where should I call writeStream and awaitTermination?


Solution

  • Yes, you can branch a Kafka input stream into as many streaming queries as you like.

    You need to consider the following:

    1. query.awaitTermination() is a blocking call, which means that any code after it will not be executed until that query terminates.
    2. Each "branched" streaming query runs in parallel, and it is important that you define a checkpoint location in each of their writeStream calls.

    Overall, your code needs to have the following structure:

    df = spark \
            .readStream \
            .format("kafka") \
            .option("kafka.bootstrap.servers", config["kafka"]["bootstrap.servers"]) \
            .option("subscribe", topic) \
            .[...]
    
    # note that I changed the variable name to "df1"
    df1 = df \
        .groupBy("myfield1") \
        .[...]
    
    df2 = df \
        .groupBy("myfield2") \
        .[...]
    
    
    query1 = df1 \
            .writeStream \
            .outputMode("update") \
            .option("checkpointLocation", "/tmp/checkpointLoc1") \
            .foreachBatch(lambda batch_df, epoch: write_data_frame(table_name, batch_df, epoch)) \
            .start()
    
    # the second query would typically write to a different table or sink
    query2 = df2 \
            .writeStream \
            .outputMode("update") \
            .option("checkpointLocation", "/tmp/checkpointLoc2") \
            .foreachBatch(lambda batch_df, epoch: write_data_frame(table_name, batch_df, epoch)) \
            .start()
    
    spark.streams.awaitAnyTermination()
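
    As a side note, write_data_frame itself is not shown in the question. A minimal sketch of such a foreachBatch handler, assuming a JDBC sink (the URL, credentials and append mode below are placeholders, not part of the original code), could look like this:

    def write_data_frame(table_name, batch_df, epoch):
        # Hypothetical sketch: persist one micro-batch to a JDBC table.
        # With outputMode("update") each batch contains updated aggregates,
        # so a real handler would usually upsert; "append" is used here
        # purely for illustration.
        batch_df.write \
            .format("jdbc") \
            .option("url", "jdbc:postgresql://localhost:5432/mydb") \
            .option("dbtable", table_name) \
            .option("user", "spark") \
            .option("password", "secret") \
            .mode("append") \
            .save()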
    

    Just an additional remark: in the code you show, you are overwriting df, so the derivation of df2 might not give you the results you intended.
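
    To make that remark concrete, here is a minimal sketch (a bare count aggregation stands in for your real one, the value parsing that the question elides with "..." is omitted, and spark, config and topic are reused from above): once df is reassigned to the first aggregation, the second groupBy no longer sees the raw stream's columns, so it either fails with a missing-column error or aggregates already-aggregated rows.

    from pyspark.sql.functions import expr

    source_df = spark.readStream \
        .format("kafka") \
        .option("kafka.bootstrap.servers", config["kafka"]["bootstrap.servers"]) \
        .option("subscribe", topic) \
        .load()

    # Pitfall (the question's structure): df is reassigned to the first
    # aggregation, so the second groupBy runs on the aggregate, not the stream
    df = source_df.groupBy("myfield1").agg(expr("count(*) as cnt"))
    df2 = df.groupBy("myfield2").agg(expr("count(*) as cnt"))   # wrong input

    # Fix (the structure shown above): derive each branch from the source
    df1 = source_df.groupBy("myfield1").agg(expr("count(*) as cnt"))
    df2 = source_df.groupBy("myfield2").agg(expr("count(*) as cnt"))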