scala, apache-spark, spark-streaming

What's the right way to "log and skip" validated transformations in spark-streaming


I have a spark-streaming application where I want to do some data transformations before my main operation, but the transformations involve some data validation.

When the validation fails, I want to log the failure cases and then proceed with the rest.

Currently, I have code like this:

def values: DStream[String] = ???
def validate(element: String): Either[String, MyCaseClass] = ???

val validationResults = values.map(validate)

// First pass: log the validation failures as a side effect
validationResults.foreachRDD { rdd =>
  rdd.foreach {
    case Left(error) => logger.error(error)
    case _           =>
  }
}

// Second pass: keep only the successfully validated records
val validatedValues: DStream[MyCaseClass] =
  validationResults.mapPartitions { partition =>
    partition.collect { case Right(record) => record }
  }

This currently works, but it feels like I'm doing something wrong.

Questions

As far as I understand, this will perform the validation twice, which is potentially wasteful.

  • Is it correct to use values.map(validate).persist() to solve that problem? (A sketch of what I mean is shown after this list.)
  • Even if I persist the values, it still iterates and pattern matches over all the elements twice. It feels like there should be a method I can use to avoid this. On a regular Scala collection I might reach for the cats TraverseFilter API, or with an fs2.Stream, evalMapFilter. What DStream API can I use for that? Maybe something with mapPartitions?
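
To be concrete, the persist() idea from the first bullet would look something like this (just a sketch of what I mean, reusing values, validate, and logger from above):

val validationResults: DStream[Either[String, MyCaseClass]] =
  values.map(validate).persist() // no-arg DStream.persist() uses a serialized in-memory storage level

// ...the foreachRDD logging and the mapPartitions filtering stay as above,
// but now both read the cached validation results instead of re-running validate.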

Solution

  • I would say that the best way to tackle this is to take advantage of the fact that flatMap accepts a function returning an Option (which the standard library implicitly converts to an Iterable):

    def values: DStream[String] = ???
    def validate(element: String): Either[String, MyCaseClass] = ???
    
    val validatedValues: DStream[MyCaseClass] =
      values.map(validate).flatMap {
        case Left(error) =>
          // Log the failure and drop the element (None contributes nothing)
          logger.error(error)
          None

        case Right(record) =>
          // Keep the successfully validated record
          Some(record)
      }
    

    You can also be a little more verbose and use mapPartitions, which should be slightly more efficient.
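
    As a rough sketch (reusing values and validate from above, and assuming an SLF4J logger obtained once per partition on the executor), that version could look like this:

    import org.slf4j.LoggerFactory

    val validatedValues: DStream[MyCaseClass] =
      values.mapPartitions { partition =>
        // Per-partition setup runs once on the executor instead of per element
        val logger = LoggerFactory.getLogger("validation")

        partition.flatMap { element =>
          validate(element) match {
            case Left(error) =>
              logger.error(error) // log and drop
              None
            case Right(record) =>
              Some(record) // keep
          }
        }
      }

    This fuses the validation and the filtering into a single pass over each partition, which is where the small efficiency gain comes from.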