apache-spark, pyspark, spark-structured-streaming

Problem with Kafka offsets in Apache Spark 3.5 structured streaming in Batch Mode


I am writing a batch query that uses Kafka as a source, following the Kafka integration guide, and I want to submit it periodically (say, once a day) to process the records that have been added since the last run. While testing in the pyspark shell I notice that every run reads all of the records, not just the ones added since the last run. My code is approximately as follows.

The question is: what do I have to change, so that each time it runs, I only process new Kafka records?

import pyspark.sql
from delta import configure_spark_with_delta_pip

builder = (pyspark.sql.SparkSession.builder.appName("MyApp")
            .master("local[*]")
            .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
            .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
            .config("spark.sql.execution.arrow.pyspark.enabled", "true")
            .config("spark.hadoop.fs.s3a.access.key", s3a_access_key)
            .config("spark.hadoop.fs.s3a.secret.key", s3a_secret_key)
            .config("spark.hadoop.fs.s3a.endpoint", s3a_host_port)
            .config("spark.hadoop.fs.s3a.path.style.access", "true")
            .config("spark.hadoop.fs.s3a.connection.ssl.enabled", "false")
            .config("spark.databricks.delta.retentionDurationCheck.enabled", "false")
            .config("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem")
            .config("spark.driver.extraJavaOptions", "-Dlog4j.configuration=file:/data/custom-log4j.properties")
           )

my_packages = [
                   # "io.delta:delta-spark_2.12:3.0.0", -> no need, since configure_spark_with_delta_pip below adds it
                   "org.apache.hadoop:hadoop-aws:3.3.4",
                   "org.apache.hadoop:hadoop-client-runtime:3.3.4",
                   "org.apache.hadoop:hadoop-client-api:3.3.4",
                   "io.delta:delta-contribs_2.12:3.0.0",
                   "io.delta:delta-hive_2.12:3.0.0",
                   "com.amazonaws:aws-java-sdk-bundle:1.12.603",
                   "org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.0",
               ]

# Create a Spark session with the builder;
# as a result, you can now read and write Delta tables.
spark = configure_spark_with_delta_pip(builder, extra_packages=my_packages).getOrCreate()


kdf = (spark
        .read
        .format("kafka")
        .option("kafka.bootstrap.servers", kafka_bootstrap_servers)
        .option("kafka.security.protocol", kafka_security_protocol)
        .option("kafka.sasl.mechanism", "SCRAM-SHA-256")
        .option("kafka.sasl.jaas.config", f"org.apache.kafka.common.security.scram.ScramLoginModule required username=\"{kafka_username}\" password=\"{kafka_password}\";")
        .option("includeHeaders", "true")
        .option("subscribe", "filebeat")
        .option("checkpointLocation", "s3a://checkpointlocation/")
        .load())

kdf = kdf.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers", "CAST(topic AS STRING)", "CAST(partition AS STRING)", "CAST(offset AS STRING)")

out = kdf...

(out.select(["message", "partition", "offset"])
    .show(truncate=False, n=MAX_JAVA_INT))


spark.stop()

This outputs a table where I can see that the same offsets are being processed with each run.


Solution

  • You are reading the topic in batch mode (the batch-query path of the Kafka integration guide), which sets startingOffsets = earliest by default, so every run starts from the beginning of the topic; in batch mode you would have to pass startingOffsets / endingOffsets explicitly on every run to narrow the range. Also, checkpointLocation has no effect in batch mode; you have to read in streaming mode (spark.readStream ...), and the processed offsets will then be stored in the checkpoint. Combined with the availableNow trigger, a streaming query gives you the run-once, process-only-new-records behaviour you want. A sketch that applies this to the code from the question follows after the note at the end of this answer.

    Example Application:

    source_df = (
        spark
        .readStream
        .format('kafka')
        .options(**{
            'kafka.bootstrap.servers': '<bootstrap-servers>',
            'subscribe': 'some_topic',
            'startingOffsets': 'earliest',
        })
        .load()
    )
    writer = (
        source_df
        .writeStream
        .format('parquet')
        .option('path', '/some_path')
        .outputMode('append')
        .option('checkpointLocation', '<some-location>')
        .trigger(availableNow=True)
    )
    
    streaming_query = writer.start()
    streaming_query.awaitTermination()
    spark.stop()
    

    First iteration of the application:

    1. The checkpointLocation is empty, so Spark reads from the earliest offsets up to the current offsets.
    2. The offsets that were reached are stored in the checkpointLocation.
    3. The application stops.

    Second iteration of the application:

    1. The checkpointLocation is not empty, so Spark starts reading from the offsets stored there up to the current offsets.
    2. The offsets that were reached are stored in the checkpointLocation (a sketch for inspecting them follows below).
    3. The application stops.
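
    If you want to sanity-check what was recorded, the offsets are kept as small text files under <checkpointLocation>/offsets. A minimal sketch for peeking at them (the exact file layout is a Spark-internal detail, but the per-partition Kafka offsets show up as JSON):

    # Print the offsets files that the streaming query wrote to its checkpoint.
    # Works for any Hadoop-compatible path, e.g. the s3a location from the question.
    (spark.read
        .text('<some-location>/offsets/*')
        .show(truncate=False))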

    Note that .option('checkpointLocation', '<some-location>') has to be called on the DataStreamWriter, NOT on the DataStreamReader.
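
    Applied to the code from the question, here is a minimal sketch (it reuses the kafka_* variables and the s3a checkpoint path from the question; process_batch is a hypothetical helper where the existing batch-style logic can live):

    def process_batch(batch_df, batch_id):
        # Called once per micro-batch; only offsets that are not yet recorded
        # in the checkpoint are included.
        out = batch_df.selectExpr(
            "CAST(key AS STRING)", "CAST(value AS STRING)", "headers",
            "CAST(topic AS STRING)", "CAST(partition AS STRING)", "CAST(offset AS STRING)")
        out.select("value", "partition", "offset").show(truncate=False)

    query = (spark
        .readStream
        .format("kafka")
        .option("kafka.bootstrap.servers", kafka_bootstrap_servers)
        .option("kafka.security.protocol", kafka_security_protocol)
        .option("kafka.sasl.mechanism", "SCRAM-SHA-256")
        .option("kafka.sasl.jaas.config", f"org.apache.kafka.common.security.scram.ScramLoginModule required username=\"{kafka_username}\" password=\"{kafka_password}\";")
        .option("includeHeaders", "true")
        .option("subscribe", "filebeat")
        .option("startingOffsets", "earliest")  # only used while the checkpoint is still empty
        .load()
        .writeStream
        .foreachBatch(process_batch)
        .option("checkpointLocation", "s3a://checkpointlocation/")  # on the writer, not the reader
        .trigger(availableNow=True)
        .start())

    query.awaitTermination()
    spark.stop()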