Tags: apache-spark, pyspark, spark-streaming, spark-structured-streaming

Preventing Spark from storing state in stream/stream joins


I have two streaming datasets, let's call them fastStream and slowStream.

The fastStream is a streaming dataset that I am consuming from Kafka via the structured streaming API. I am expecting to receive potentially thousands of messages a second.

The slowStream is actually a reference (or lookup) table that is being 'upserted' by another stream, and it contains data that I want to join onto each message in the fastStream before I save the records to a table. The slowStream is only updated when someone changes the metadata, which can happen at any time, but we would expect it to change maybe once every few days.

Each record in the fastStream will have exactly one corresponding message in the slowStream and I essentially want to make that join happen immediately with whatever data is in the slowStream table. I don't want to wait to see if a potential match could occur if new data arrives in the slowStream.

The problem that I have is that according to the Spark docs:

Hence, for both the input streams, we buffer past input as streaming state, so that we can match every future input with past input and accordingly generate joined results.

I have tried adding a watermark to the fastStream, but I think it has no effect, since the docs indicate that the watermarked columns need to be referenced in the join condition.
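For context, this is the shape the Spark docs use for bounding state in stream-stream joins: both sides are watermarked, and the join condition carries an explicit event-time constraint between the two streams' timestamp columns. The column names below are illustrative; the point is that this pattern requires an event time on both sides, which is exactly what the slowStream cannot provide:

```python
from pyspark.sql.functions import expr

# Documented pattern: watermark both streams (names are illustrative)
fast = fastStream.withWatermark("fast_ts", "1 hour").alias("fastStream")
slow = slowStream.withWatermark("slow_ts", "2 hours").alias("slowStream")

# The time-range constraint lets Spark expire buffered state on both sides
joined = fast.join(
    slow,
    expr("""
        fastStream.slow_id = slowStream.id AND
        fast_ts >= slow_ts AND
        fast_ts <= slow_ts + interval 1 hour
    """),
    "inner"
)
```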

Ideally I would write something like:

# Apply a watermark to the fast stream
from pyspark.sql.functions import expr

fastStream = spark.readStream \
  .format("delta") \
  .load("dbfs:/mnt/some_file/fastStream") \
  .withWatermark("timestamp", "1 hour") \
  .alias("fastStream")

# The slowStream cannot be watermarked since it is only slowly changing
slowStream = spark.readStream \
  .format("delta") \
  .load("dbfs:/mnt/some_file/slowStream") \
  .alias("slowStream")

# Prevent the join from buffering the fast stream by 'telling' Spark
# that there will never be new matches. ('watermark' here is not a real
# identifier; it is the expression I would like to be able to write.)
fastStream.join(
  slowStream,
  expr("""
    fastStream.slow_id = slowStream.id
    AND fastStream.timestamp > watermark
  """),
  "inner"
).select("fastStream.*", "slowStream.metadata")

But I don't think you can reference the watermark in the SQL expression.

Essentially, while I'm happy to have the slowStream buffered (so the whole table is in memory), I can't have the fastStream buffered, as that would quickly consume all memory. Instead, I would simply like to drop messages from the fastStream that aren't matched, rather than retaining them to see if they might match in future.

Any help very gratefully appreciated.


Solution

  • Answering my own question with what I ended up going with. It's certainly not ideal, but for all my searching there doesn't seem to be a control within Spark Structured Streaming to address this use case.

    So my solution was to read the dataset and conduct the join inside a foreachBatch. This way I prevent Spark from storing a ton of unnecessary state and get the joins conducted immediately. On the downside, there seems to be no way to incrementally read a stream table so instead, I am re-reading the entire table every time...

    from pyspark.sql.functions import expr

    def join_slow_stream(df, batch_id):
      # Read the slow table as a static (batch) table rather than a stream
      slowdf = spark.read \
        .format("delta") \
        .load("dbfs:/mnt/some_file/slowStream") \
        .alias("slowStream")

      # Re-alias the micro-batch, since the stream's alias is not
      # guaranteed to survive into foreachBatch
      out_df = df.alias("fastStream").join(
        slowdf,
        expr("fastStream.slow_id = slowStream.id"),
        "inner"
      ).select("fastStream.*", "slowStream.metadata")

      # write data to database
      db_con.write(out_df)


    fastStream.writeStream.foreachBatch(join_slow_stream).start()
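    A closely related pattern worth knowing about: because the slow table is Delta, the same thing can be expressed as a top-level stream-static join, which Structured Streaming executes statelessly. With Delta as the static side, each micro-batch should resolve against the latest version of the table, making it behave much like the foreachBatch re-read above without the wrapper. This is a sketch under those assumptions (the checkpoint and output paths are illustrative), and the snapshot-freshness behaviour is worth verifying on your runtime:

    ```python
    from pyspark.sql.functions import expr

    # Static (batch) read of the slow Delta table; no stream state is kept
    slow_static = spark.read \
        .format("delta") \
        .load("dbfs:/mnt/some_file/slowStream") \
        .alias("slowStream")

    out = fastStream.alias("fastStream").join(
        slow_static,
        expr("fastStream.slow_id = slowStream.id"),
        "inner"
    ).select("fastStream.*", "slowStream.metadata")

    # A checkpoint location is required for a production stream
    query = out.writeStream \
        .format("delta") \
        .option("checkpointLocation", "dbfs:/mnt/some_file/_checkpoints/join") \
        .start("dbfs:/mnt/some_file/joined")
    ```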