I am receiving a Kafka stream in PySpark. Currently I am grouping it by one set of fields and writing the updates to a database:
df = spark \
.readStream \
.format("kafka") \
.option("kafka.bootstrap.servers", config["kafka"]["bootstrap.servers"]) \
.option("subscribe", topic)
...
df = df \
.groupBy("myfield1") \
.agg(
expr("count(*) as cnt"),
min(struct(col("mycol.myfield").alias("mmm"), col("*"))).alias("minData")
) \
.select("cnt", "minData.*") \
.select(
col("...").alias("..."),
...
col("userId").alias("user_id")
)
query = df \
.writeStream \
.outputMode("update") \
.foreachBatch(lambda df, epoch: write_data_frame(table_name, df, epoch)) \
.start()
query.awaitTermination()
Can I branch off the same chain in the middle and create another grouping, like this:
df2 = df \
.groupBy("myfield2") \
.agg(
expr("count(*) as cnt"),
min(struct(col("mycol.myfield").alias("mmm"), col("*"))).alias("minData")
) \
.select("cnt", "minData.*") \
.select(
col("...").alias("..."),
...
col("userId").alias("user_id")
)
and write its output to a different place in parallel? Where should I call writeStream and awaitTermination?
Yes, you can branch a Kafka input stream into as many streaming queries as you like.
You need to consider the following: query.awaitTermination() is a blocking method, so any code after that call is not executed until that query terminates. Start all of your queries first and block only once at the end.
Overall, your code needs to have the following structure:
df = spark \
.readStream \
.format("kafka") \
.option("kafka.bootstrap.servers", config["kafka"]["bootstrap.servers"]) \
.option("subscribe", topic) \
.[...]
# note that I changed the variable name to "df1"
df1 = df \
.groupBy("myfield1") \
.[...]
df2 = df \
.groupBy("myfield2") \
.[...]
query1 = df1 \
.writeStream \
.outputMode("update") \
.option("checkpointLocation", "/tmp/checkpointLoc1") \
.foreachBatch(lambda batch_df, epoch: write_data_frame(table_name, batch_df, epoch)) \
.start()
query2 = df2 \
.writeStream \
.outputMode("update") \
.option("checkpointLocation", "/tmp/checkpointLoc2") \
.foreachBatch(lambda batch_df, epoch: write_data_frame(table_name, batch_df, epoch)) \
.start()
# blocks until any one of the started queries terminates
spark.streams.awaitAnyTermination()
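If the two aggregations should end up in different tables, one option is to build each foreachBatch callback with a small factory, so that every query is bound to its own target and always forwards the micro-batch it receives. This is only a minimal sketch: it assumes write_data_frame takes (table_name, df, epoch) as in the question, and make_batch_writer, fake_write_data_frame and the table names are made up for illustration.

```python
from typing import Any, Callable


def make_batch_writer(
    table_name: str,
    write_fn: Callable[[str, Any, int], None],
) -> Callable[[Any, int], None]:
    """Return a foreachBatch-style callback bound to one target table."""

    def write_batch(batch_df: Any, epoch_id: int) -> None:
        # Forward the micro-batch handed in by Spark, not the streaming
        # DataFrame the query was built from.
        write_fn(table_name, batch_df, epoch_id)

    return write_batch


# Stand-in sink so the sketch runs without Spark; it only records the calls.
written = []


def fake_write_data_frame(table_name, batch_df, epoch_id):
    written.append((table_name, batch_df, epoch_id))


# Hypothetical table names, one per aggregation.
writer1 = make_batch_writer("table_by_myfield1", fake_write_data_frame)
writer2 = make_batch_writer("table_by_myfield2", fake_write_data_frame)

# Spark would call these once per micro-batch; here they are called by hand.
writer1("batch-0-of-query1", 0)
writer2("batch-0-of-query2", 0)
```

In the real job you would pass make_batch_writer(..., write_data_frame) to .foreachBatch(...) on df1.writeStream and df2.writeStream respectively, instead of the lambdas.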
One additional remark: in the code you posted, you overwrite df with the first aggregation, so deriving df2 from it would group the already-aggregated result instead of the original Kafka stream, which is probably not what you intended.
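To make that remark concrete, here is a small Spark-free illustration of what rebinding the name does. FakeFrame is a hypothetical stand-in that only records which steps were applied; it is not a Spark class, and the real aggregation steps are left out.

```python
class FakeFrame:
    """Hypothetical stand-in for a DataFrame: it only records applied steps."""

    def __init__(self, steps):
        self.steps = list(steps)

    def groupBy(self, key):
        # Like Spark transformations, return a new object instead of
        # mutating the current one.
        return FakeFrame(self.steps + [f"groupBy({key})"])


# Pattern from the question: the name df is rebound to the first aggregation.
df = FakeFrame(["kafka"])
df = df.groupBy("myfield1")
df2_wrong = df.groupBy("myfield2")  # built on top of the first grouping

# Pattern from the answer: keep the source under its own name, branch twice.
source = FakeFrame(["kafka"])
df1 = source.groupBy("myfield1")
df2 = source.groupBy("myfield2")

print(df2_wrong.steps)  # ['kafka', 'groupBy(myfield1)', 'groupBy(myfield2)']
print(df2.steps)        # ['kafka', 'groupBy(myfield2)']
```

Keeping the Kafka source under a name that is never reassigned makes it obvious which DataFrame each branch starts from.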