Tags: apache-spark, aggregate, spark-streaming, spark-structured-streaming

Spark Stateful Structured Streaming: State getting too big in mapGroupsWithState


I am trying to use the mapGroupsWithState method for stateful structured streaming on my incoming stream of data. The problem I face is that the key I am choosing for groupByKey makes my state grow too big too fast. The obvious way out would be to change the key, but the business logic I wish to apply in the update method requires the key to be exactly the one I have right now, OR, if it is possible, access to the GroupState for all keys.

For example, I have a stream of data coming in from various organizations, and typically an organization contains a userId, personId, etc. Please see the code below:

val stream: Dataset[User] = dataFrame.as[User]
val noTimeout = GroupStateTimeout.NoTimeout
val statisticStream = stream
    .groupByKey(key => key.orgId)
    .mapGroupsWithState(noTimeout)(updateUserStatistic)

val df = statisticStream.toDF()

val query = df
    .writeStream
    .outputMode(OutputMode.Update())
    .option("checkpointLocation", s"$checkpointLocation/$name")
    .foreach(new UserCountWriter(spark.sparkContext.getConf))
    .queryName(name)
    .trigger(Trigger.ProcessingTime(Duration.apply("10 seconds")))
    .start()

case classes:

case class User(
  orgId: Long,
  profileId: Long,
  userId: Long)

case class UserStatistic(
  orgId: Long,
  known: Long,
  unknown: Long,
  userSeq: Seq[User])

update method:

def updateUserStatistic(
  orgId: Long,
  newEvents: Iterator[User],
  oldState: GroupState[UserStatistic]): UserStatistic = {
  var state: UserStatistic = if (oldState.exists) oldState.get else UserStatistic(orgId, 0L, 0L, Seq.empty)
  for (event <- newEvents) {
    // business logic like checking whether the userId in this organization is of a certain type
    // and then updating the known or unknown attribute for that particular user accordingly
  }
  oldState.update(state)
  state
}

The problem gets worse when I have to execute this on the Driver-Executor model, as I am expecting 1-10 million users in every organization, which could mean that many states on a single executor (correct me if I am wrong in this understanding).

Possible solutions that failed:

  1. Grouping by userId as the key - because then I am unable to get all userIds for a given orgId, since a GroupState is kept per aggregation key and here that key is the userId. So for every new userId a new state is created, even if it belongs to the same organization.

Any help or suggestions are appreciated.


Solution

  • Your state keeps increasing in size because in the current implementation no key/state pair will ever be removed from the GroupState.

    To mitigate exactly the problem you are facing (infinitely increasing state), the mapGroupsWithState method allows you to use a timeout. You can choose between two types of timeouts:

    • Processing-Time timeouts using GroupStateTimeout.ProcessingTimeTimeout with GroupState.setTimeoutDuration(), or
    • Event-Time timeouts using GroupStateTimeout.EventTimeTimeout with GroupState.setTimeoutTimestamp().

    Note that the difference between them is that the former is a duration-based timeout while the latter is the more flexible timestamp-based timeout.
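
    Applied to your pipeline, switching from NoTimeout to a processing-time timeout only requires passing the corresponding configuration into mapGroupsWithState (a minimal sketch reusing the stream and updateUserStatistic names from your question); note that with this configuration the mapping function itself must also call setTimeoutDuration, as shown in the template further below:

    import org.apache.spark.sql.streaming.GroupStateTimeout

    val statisticStream = stream
        .groupByKey(key => key.orgId)
        .mapGroupsWithState(GroupStateTimeout.ProcessingTimeTimeout)(updateUserStatistic)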

    In the ScalaDocs of the trait GroupState you will find a nice template on how to use timeouts in your mapping function:

    def mappingFunction(key: String, value: Iterator[Int], state: GroupState[Int]): String = {
    
      if (state.hasTimedOut) {                // If called when timing out, remove the state
        state.remove()
    
      } else if (state.exists) {              // If state exists, use it for processing
        val existingState = state.get         // Get the existing state
        val shouldRemove = ...                // Decide whether to remove the state
        if (shouldRemove) {
          state.remove()                      // Remove the state
    
        } else {
          val newState = ...
          state.update(newState)              // Set the new state
          state.setTimeoutDuration("1 hour")  // Set the timeout
        }
    
      } else {
        val initialState = ...
        state.update(initialState)            // Set the initial state
        state.setTimeoutDuration("1 hour")    // Set the timeout
      }
      ...
      // return something
    }
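
    Adapted to your UserStatistic state, the mapping function could look roughly like the sketch below; the one-hour duration and the choice to simply drop a group's state once it times out are assumptions you would replace with whatever fits your business logic:

    def updateUserStatistic(
        orgId: Long,
        newEvents: Iterator[User],
        oldState: GroupState[UserStatistic]): UserStatistic = {
      if (oldState.hasTimedOut) {
        // Called because this group received no data within the timeout: drop its state
        val lastKnown = oldState.get
        oldState.remove()
        lastKnown
      } else {
        var state = if (oldState.exists) oldState.get else UserStatistic(orgId, 0L, 0L, Seq.empty)
        for (event <- newEvents) {
          // your business logic: classify the user and update the known/unknown counters
        }
        oldState.update(state)
        oldState.setTimeoutDuration("1 hour") // re-arm the timeout whenever new data arrives
        state
      }
    }

    With a timeout in place, groups that stop receiving events are eventually evicted from the state store instead of accumulating forever.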