Spark Stateful Structured Streaming: State getting too big in mapGroupsWithState

I am trying to use the mapGroupsWithState method for stateful structured streaming over my incoming stream of data. The problem is that the key I chose for groupByKey makes my state grow too big, too fast. The obvious way out would be to change the key, but the business logic I want to apply in the update method requires the key to be exactly as it is right now or, if that is possible, access to the GroupState of all keys.

For example, I have a stream of data coming in from various organizations, and each organization typically contains userId, personId, etc. Please see the code below:

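import org.apache.spark.sql.Dataset
import org.apache.spark.sql.streaming.{GroupStateTimeout, Trigger}
import scala.concurrent.duration.Duration

val stream: Dataset[User] = ???  // source expression is garbled in the post; presumably a streaming Dataset converted with .as[User]
val noTimeout = GroupStateTimeout.NoTimeout
val statisticStream = stream
    .groupByKey(key => key.orgId)
    .mapGroupsWithState(noTimeout)(updateUserStatistic)

val df = statisticStream.toDF()

val query = df
    .writeStream
    .outputMode("update")  // mapGroupsWithState only supports the update output mode
    .option("checkpointLocation", s"$checkpointLocation/$name")
    .foreach(new UserCountWriter(spark.sparkContext.getConf))
    .trigger(Trigger.ProcessingTime(Duration.apply("10 seconds")))
    .start()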

case classes:

case class User(
  orgId: Long,
  profileId: Long,
  userId: Long)

case class UserStatistic(
  orgId: Long,
  known: Long,
  unknown: Long,
  userSeq: Seq[User])

update method:

def updateUserStatistic(
  orgId: Long,
  newEvents: Iterator[User],
  oldState: GroupState[UserStatistic]): UserStatistic = {
    var state: UserStatistic = if (oldState.exists) oldState.get else UserStatistic(orgId, 0L, 0L, Seq.empty)
    for (event <- newEvents) {
      // Business logic: check whether the userId in this organization is of a certain type,
      // then update the known or unknown attribute for that particular user accordingly.
    }
    oldState.update(state)  // persist the updated state for this orgId
    state
}

The problem gets worse when I run this on a cluster (driver-executor model), as I expect 1-10 million users in every organization, which could mean that many states on a single executor (correct me if I have misunderstood this).

Possible solutions that failed:

  1. Grouping by userId as the key: I am then unable to get all userIds for a given orgId, because each GroupState is stored as a key/value pair under its grouping key, which here is the userId. So a new state is created for every new userId, even if it belongs to the same organization.

Any help or suggestions are appreciated.


  • Your state keeps growing because, in the current implementation, no key/state pair is ever removed from the state store.

    To mitigate exactly the problem you are facing (unbounded state growth), the mapGroupsWithState method allows you to use a timeout. You can choose between two types of timeouts:

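    • Processing-time timeouts using GroupStateTimeout.ProcessingTimeTimeout with GroupState.setTimeoutDuration(), or
    • Event-time timeouts using GroupStateTimeout.EventTimeTimeout with GroupState.setTimeoutTimestamp().

    Note the difference between them: a processing-time timeout is duration-based (it fires once the given wall-clock duration has passed without new data for the key), whereas an event-time timeout is the more flexible timestamp-based variant (it fires once the query's watermark passes the timestamp you set, so the query must define a watermark with withWatermark).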
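    As a rough illustration of how the two timeout types are wired up, here is a minimal sketch. The Event case class, its eventTime column, the TimeoutKinds object, the 10-minute watermark and the one-hour expiry are all assumptions made for this example; only the Spark calls themselves come from the Structured Streaming API.

```scala
import java.sql.Timestamp
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}

// Hypothetical event type; eventTime is an assumed column used for the watermark.
case class Event(orgId: Long, userId: Long, eventTime: Timestamp)

object TimeoutKinds {
  val ttlMs: Long = 60 * 60 * 1000L // assumed expiry of one hour

  // Expiry timestamp (epoch millis) for the event-time variant.
  def expiryMs(latestEventMs: Long): Long = latestEventMs + ttlMs

  // Processing-time timeout: the key expires after one hour of wall-clock time without new data.
  def countWithProcessingTimeout(orgId: Long, events: Iterator[Event], state: GroupState[Long]): Long =
    if (state.hasTimedOut) { state.remove(); 0L }
    else {
      val total = state.getOption.getOrElse(0L) + events.size
      state.update(total)
      state.setTimeoutDuration("1 hour")
      total
    }

  // Event-time timeout: the key expires once the watermark passes the timestamp set here.
  def countWithEventTimeout(orgId: Long, events: Iterator[Event], state: GroupState[Long]): Long =
    if (state.hasTimedOut) { state.remove(); 0L }
    else {
      val batch = events.toSeq // non-empty whenever the call is not a timeout
      val total = state.getOption.getOrElse(0L) + batch.size
      state.update(total)
      state.setTimeoutTimestamp(expiryMs(batch.map(_.eventTime.getTime).max))
      total
    }

  // Builds both variants from the same streaming Dataset.
  def wire(spark: SparkSession, events: Dataset[Event]): (Dataset[Long], Dataset[Long]) = {
    import spark.implicits._
    val byProcessingTime = events
      .groupByKey(_.orgId)
      .mapGroupsWithState[Long, Long](GroupStateTimeout.ProcessingTimeTimeout)(countWithProcessingTimeout)
    val byEventTime = events
      .withWatermark("eventTime", "10 minutes") // required for event-time timeouts
      .groupByKey(_.orgId)
      .mapGroupsWithState[Long, Long](GroupStateTimeout.EventTimeTimeout)(countWithEventTimeout)
    (byProcessingTime, byEventTime)
  }
}
```

    The timeout type passed to mapGroupsWithState has to match the setter used inside the function: calling setTimeoutDuration under EventTimeTimeout, or setTimeoutTimestamp under ProcessingTimeTimeout, is rejected at runtime.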

    In the ScalaDocs of the trait GroupState you will find a nice template on how to use timeouts in your mapping function:

    def mappingFunction(key: String, value: Iterator[Int], state: GroupState[Int]): String = {
      if (state.hasTimedOut) {                // If called when timing out, remove the state
        state.remove()
      } else if (state.exists) {              // If state exists, use it for processing
        val existingState = state.get         // Get the existing state
        val shouldRemove = ...                // Decide whether to remove the state
        if (shouldRemove) {
          state.remove()                      // Remove the state
        } else {
          val newState = ...
          state.update(newState)              // Set the new state
          state.setTimeoutDuration("1 hour")  // Set the timeout
        }
      } else {
        val initialState = ...
        state.update(initialState)            // Set the initial state
        state.setTimeoutDuration("1 hour")    // Set the timeout
      }
      ...
      // return something
    }
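    Applied to the types from the question, the template might look roughly like the sketch below. The case classes are repeated from the question (with the unknown field spelled out) so that the block stands on its own. The isKnown rule is a purely hypothetical stand-in for the business logic that the question leaves out, and the one-hour duration is likewise an assumption.

```scala
import org.apache.spark.sql.streaming.GroupState

case class User(orgId: Long, profileId: Long, userId: Long)
case class UserStatistic(orgId: Long, known: Long, unknown: Long, userSeq: Seq[User])

object UserStatisticUpdate {
  // Hypothetical placeholder for the real "is this user of a certain type" check.
  def isKnown(user: User): Boolean = user.profileId > 0

  // Pure state transition, kept separate from GroupState so it can be checked without a cluster.
  def fold(state: UserStatistic, events: Iterator[User]): UserStatistic =
    events.foldLeft(state) { (acc, user) =>
      if (isKnown(user)) acc.copy(known = acc.known + 1, userSeq = acc.userSeq :+ user)
      else acc.copy(unknown = acc.unknown + 1, userSeq = acc.userSeq :+ user)
    }

  def updateUserStatistic(
      orgId: Long,
      newEvents: Iterator[User],
      oldState: GroupState[UserStatistic]): UserStatistic = {
    if (oldState.hasTimedOut) {
      // No data arrived for this orgId within the timeout: emit the last value and drop the state.
      val last = oldState.get
      oldState.remove()
      last
    } else {
      val current = oldState.getOption.getOrElse(UserStatistic(orgId, 0L, 0L, Seq.empty))
      val updated = fold(current, newEvents)
      oldState.update(updated)
      oldState.setTimeoutDuration("1 hour") // restart the timeout whenever new data arrives
      updated
    }
  }
}
```

    For this to work, the query has to pass GroupStateTimeout.ProcessingTimeTimeout to mapGroupsWithState instead of GroupStateTimeout.NoTimeout. Note that a timeout only evicts organizations that stop sending data; the userSeq of an active organization still grows with every event.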