Problem: I am receiving data for multiple tables/schemas in a single stream. After segregating the data by table, I open a parallel write stream for each table.
The function I use in `foreachBatch` is:
def writeToAurora(df, batch_id, tableName):
    """Overwrite the ``<tablename>_delta`` staging table with one micro-batch.

    Meant to be passed to ``foreachBatch``, which supplies the micro-batch
    DataFrame and its batch id. Bind *tableName* eagerly (for example with
    ``functools.partial``) instead of through a lambda that closes over a
    loop variable.

    Args:
        df: Micro-batch DataFrame to write.
        batch_id: Micro-batch id supplied by ``foreachBatch``; not used here.
        tableName: Logical table name; the JDBC target table is
            ``"<tableName lowercased>_delta"``.
    """
    # NOTE(review): the batch is written only once below, so persist() may not
    # buy anything here; consider dropping it.
    df = df.persist()
    stagingTable = f'{tableName.lower()}_delta'
    try:
        (
            df.write
            .mode("overwrite")
            .format("jdbc")
            # With overwrite mode, truncate=true makes Spark truncate the
            # existing table instead of dropping and recreating it.
            .option("truncate", "true")
            .option("driver", DB_conf['DRIVER'])
            .option("batchsize", 1000)
            .option("url", DB_conf['URL'])
            .option("dbtable", stagingTable)
            .option("user", DB_conf['USER_ID'])
            .option("password", DB_conf['PASSWORD'])
            .save()
        )
    finally:
        # Release the cached batch even when the JDBC write raises.
        df.unpersist()
The logic that opens the multiple write streams is:
# Read the combined multi-table feed from a single Kinesis stream.
data_df = spark.readStream.format("kinesis") \
    .option("streamName", stream_name) \
    .option("startingPosition", initial_position) \
    .load()
#Distinguishing table wise df
# One filtered streaming DataFrame per value of the TableName column.
distinctTables = ['Table1', 'Table2', 'Table3']
tablesDF = {table: data_df.filter(f"TableName = '{table}'") for table in distinctTables}
#Processing Each Table
for table, tableDF in tablesDF.items():
    # Parse the pipe-delimited 'finalData' column using this table's schema.
    df = tableDF.withColumn('csvData', F.from_csv('finalData', schema=tableSchema[table], options={'sep': '|','quote': '"'}))\
        .select('csvData.*')
    # BUG: the lambda passed to foreachBatch closes over the variable `table`,
    # not its current value. It reads `table` only when a micro-batch runs, so
    # it sees whatever `table` holds at that moment. That value changes as this
    # loop advances, and the awaitTermination loop below rebinds `table` as
    # well. As a result, batches can be written to another table's staging
    # table. Bind the value when the callback is created instead (for example
    # functools.partial or a default argument).
    vars()[table+'_query'] = df.writeStream\
        .trigger(processingTime='120 seconds') \
        .foreachBatch(lambda fdf, batch_id: writeToAurora(fdf, batch_id, table)) \
        .option("checkpointLocation", f"s3://{bucket}/temporary/checkpoint/{table}")\
        .start()
# Block on each started query. The query handles were stored as names
# "<table>_query" via vars() above.
# NOTE(review): keeping the handles in a dict would avoid vars()/eval.
for table in tablesDF.keys():
    eval(table+'_query').awaitTermination()
Issue: when I run the above code, Table1's data sometimes gets loaded into Table2, and the mix-up is different on every run. The mapping between each DataFrame and the table it should be loaded into is not preserved.
I need help understanding why this happens.
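This is caused by late binding in the lambda you pass to `foreachBatch`: the lambda captures the variable `table`, not its value, and looks it up only when a micro-batch is processed. At that point it sees whatever value `table` happens to hold, which is not necessarily the table that the stream was created for.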
Here's an example. It tries to write every table to "t2" and fails (it actually writes only the "t2" table, but with the "t0" data):
from pyspark.sql.functions import *
from pyspark.sql import *
def writeToTable(df, epochId, table_name):
    """Overwrite the table ``custanwo.dsci.stream_test_<table_name>`` with *df*.

    Used as a ``foreachBatch`` callback; *epochId* is accepted to match the
    callback signature but is not used.
    """
    target_table = f"custanwo.dsci.stream_test_{table_name}"
    overwriting_writer = df.write.mode("overwrite")
    overwriting_writer.saveAsTable(target_table)
data_df = spark.readStream.format("rate").load()
# Demo data: bucket the rate source's values by `value % 10`, count per key,
# and tag each key with a table name "t0"/"t1"/"t2" (key % 3).
data_df = (data_df
    .selectExpr("value % 10 as key")
    .groupBy("key")
    .count()
    .withColumn("t", concat(lit("t"),(col("key")%3).astype("string")))
)
table_names = ["t0", "t1", "t2"]
table_df = {t: data_df.filter(f"t = '{t}'") for t in table_names}
for t, df in table_df.items():
    # Intentional BUG (this is the demonstration): the lambda captures the
    # variable `t`, not its current value. It reads `t` only when a micro-batch
    # is processed, typically after this loop has finished and `t` is "t2", so
    # every query writes to the same table.
    vars()[f"{t}_query"] = (df
        .writeStream
        .foreachBatch(lambda df, epochId: writeToTable(df, epochId, t))
        .outputMode("update")
        .start()
    )
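To resolve this there are a few options. One is `functools.partial`, which binds the table name at the moment the callback is created (a default argument such as `lambda df, epochId, t=t: writeToTable(df, epochId, t)` works the same way):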
from functools import partial
def writeToTable(df, epochId, table_name):
    """Overwrite the table ``custanwo.dsci.stream_test_<table_name>`` with *df*.

    Used as a ``foreachBatch`` callback; *epochId* is accepted to match the
    callback signature but is not used.
    """
    target_table = f"custanwo.dsci.stream_test_{table_name}"
    overwriting_writer = df.write.mode("overwrite")
    overwriting_writer.saveAsTable(target_table)
# Demo input: rate-source values bucketed by `value % 10`, counted per key,
# and tagged with a table name "t0"/"t1"/"t2" (key % 3).
data_df = spark.readStream.format("rate").load()
data_df = data_df.selectExpr("value % 10 as key").groupBy("key").count()
data_df = data_df.withColumn(
    "t", concat(lit("t"), (col("key") % 3).astype("string"))
)
table_names = ["t0", "t1", "t2"]
# One filtered streaming DataFrame per table name.
table_df = {name: data_df.filter(f"t = '{name}'") for name in table_names}
for t, df in table_df.items():
    # partial() stores the current value of `t` right now, so each query keeps
    # its own table name instead of looking `t` up later.
    vars()[f"{t}_query"] = (
        df.writeStream
        .foreachBatch(partial(writeToTable, table_name=t))
        .outputMode("update")
        .start()
    )
In your code, add `from functools import partial` and rewrite the `writeStream` call inside the loop to:
# Goes inside the `for table, tableDF in tablesDF.items():` loop, where `df`
# and `table` are defined. partial() fixes `tableName` to the current value of
# `table` when the callback is created.
vars()[table + '_query'] = (
    df.writeStream
    .trigger(processingTime='120 seconds')
    .foreachBatch(partial(writeToAurora, tableName=table))
    .option("checkpointLocation", f"s3://{bucket}/temporary/checkpoint/{table}")
    .start()
)