Tags: scala, apache-spark, traits, extend, case-class

In Scala, how do you ensure an Encoder is available for a type that is passed into a generic function which only constrains it with certain traits?


I have a function called createTimeLineDS that takes another function as input and applies that function inside an internal Dataset map call. createTimeLineDS only constrains the input function's type signature with traits, while map requires an implicit Encoder for the type the function returns.

For some reason, when I pass a function that returns a case class into this function, it fails to compile with this error:

    Unable to find encoder for type TIMELINE. An implicit Encoder[TIMELINE] is needed to store TIMELINE instances in a Dataset. Primitive types (Int, String, etc) and Product types (case classes) are supported by importing spark.implicits._  Support for serializing other types will be added in future releases.
    [error]       .map({ case ((event, team), user) => convertEventToTimeLineFunction(event, team, user)})

The code is below; I have all the traits and the case classes defined. The problem is the last function: the call to it generates the error above. I have the import sparkSession.implicits._ in place, so I am not sure how to do this correctly.

Traits, case classes, and the function that's used as a parameter:

trait Event {
  val teamId: String
  val actorId: String
}

trait TimeLine {
  val teamDomain: Option[String]
  val teamName: Option[String]
  val teamIsTest: Option[Boolean]
  val actorEmail: Option[String]
  val actorName: Option[String]
}  

case class JobEventTimeline(
                         jobId: String,
                         jobType: Option[String],
                         inPlanning: Option[Boolean],

                         teamId: String,
                         actorId: String,
                         adminActorId: Option[String],
                         sessionId: String,
                         clientSessionId: Option[String],
                         clientCreatedAt: Long,
                         seqId: Long,
                         isSideEffect: Option[Boolean],

                         opAction: String,
                         stepId: Option[String],
                         jobBaseStepId: Option[String],
                         fieldId: Option[String],

                         serverReceivedAt: Option[Long],

                         // "Enriched" data. Data is pulled in from other sources during stream processing
                         teamDomain: Option[String] = None,
                         teamName: Option[String] = None,
                         teamIsTest: Option[Boolean] = None,

                         actorEmail: Option[String] = None,
                         actorName: Option[String] = None
                       ) extends TimeLine


def createJobEventTimeLine(jobEvent: CaseClassJobEvent, team: Team, user: User): JobEventTimeline = {
    JobEventTimeline(
      jobEvent.jobId,
      jobEvent.jobType,
      jobEvent.inPlanning,
      jobEvent.teamId,
      jobEvent.actorId,
      jobEvent.adminActorId,
      jobEvent.sessionId,
      jobEvent.clientSessionId,
      jobEvent.clientCreatedAt,
      jobEvent.seqId,
      jobEvent.isSideEffect,
      jobEvent.opAction,
      jobEvent.stepId,
      jobEvent.jobBaseStepId,
      jobEvent.fieldId,
      jobEvent.serverReceivedAt,
      Some(team.domain),
      Some(team.name),
      Some(team.is_test),
      Some(user.email),
      Some(user.name)
    )
  }

The problem function and the function call:

def createTimeLineDS[EVENT <: Event with Serializable, TIMELINE <: TimeLine]
  (convertEventToTimeLineFunction: (EVENT, Team, User) => TIMELINE)
  (sparkSession: SparkSession)
  (jobEventDS: Dataset[EVENT]): Dataset[TIMELINE] = {
    import sparkSession.implicits._
    val teamDS = FuncUtils.createDSFromPostgresql[Team](sparkSession)
    val userDS = FuncUtils.createDSFromPostgresql[User](sparkSession)
    jobEventDS
      .joinWith(teamDS, jobEventDS("teamId") === teamDS("id"), "left_outer")
      .joinWith(userDS, $"_1.actorId" === userDS("id"), "left_outer")
      .map({ case ((event, team), user) => convertEventToTimeLineFunction(event, team, user) })
  }

Function call:

val jobEventTimeLine = FuncUtils.createTimeLineDS(JobEventTimeline.createJobEventTimeLine)(sparkSession)(jobEventDS)

Solution

  • The simplest solution would be to do this instead:

    def createTimeLineDS[EVENT <: Event, TIMELINE <: TimeLine : Encoder](...)
    

    You probably wouldn't need the sparkSession parameter, nor the import sparkSession.implicits._ line either
    (but you may need more changes; keep reading).

    So, the problem is that the map method on a Dataset needs an implicit Encoder for the output type. What you are doing with that funny syntax (called a context bound) is saying that your method also requires such an implicit, so the compiler will be happy as long as the caller of your method provides it (usually through an import spark.implicits._ somewhere before the call). See the small illustration below.
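
    As a small illustration (with hypothetical helper names, not taken from the question), a context bound on the result type is just sugar for an extra implicit parameter list, and that implicit is exactly what map is looking for:

    import org.apache.spark.sql.{Dataset, Encoder}

    // `TIMELINE : Encoder` is syntactic sugar for the extra implicit parameter below.
    def mapToTimeLine[EVENT, TIMELINE : Encoder](ds: Dataset[EVENT])(f: EVENT => TIMELINE): Dataset[TIMELINE] =
      ds.map(f) // map's own implicit Encoder[TIMELINE] is satisfied by the context bound

    // Equivalent, desugared form:
    def mapToTimeLineExplicit[EVENT, TIMELINE](ds: Dataset[EVENT])(f: EVENT => TIMELINE)
                                              (implicit enc: Encoder[TIMELINE]): Dataset[TIMELINE] =
      ds.map(f)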

    For more information about implicits, where the compiler searches for them, and why you need an Encoder, please read the linked articles.


    Now, after you have read all that, I would expect you to see what the problem was and how to fix it.
    But you will probably still need the explicit import sparkSession.implicits._ in your method. That is probably because FuncUtils.createDSFromPostgresql[Team](sparkSession) does the same, but you now know how to refactor it (a sketch follows below).
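
    The implementation of createDSFromPostgresql isn't shown in the question, so this is a purely hypothetical sketch of what it might look like after the same refactor; the JDBC options are placeholders, and only the T : Encoder context bound is the point:

    import org.apache.spark.sql.{Dataset, Encoder, SparkSession}

    // Hypothetical shape: the encoder now comes from the caller via the context bound,
    // instead of from an `import sparkSession.implicits._` inside the method.
    def createDSFromPostgresql[T : Encoder](sparkSession: SparkSession): Dataset[T] =
      sparkSession.read
        .format("jdbc")
        .option("url", "jdbc:postgresql://localhost/db") // placeholder connection string
        .option("dbtable", "placeholder_table")          // placeholder table name
        .load()
        .as[T]                                           // .as needs the implicit Encoder[T]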

    Also, since Team & User are concrete classes that you control, you may add something like this to their companion objects, so you do not need to ask for their encoders, because they will always be in the implicit scope.

    import org.apache.spark.sql.{Encoder, Encoders}

    object Team {
      // https://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.sql.Encoders$@product[T%3C:Product](implicitevidence$5:reflect.runtime.universe.TypeTag[T]):org.apache.spark.sql.Encoder[T]
      implicit final val TeamEncoder: Encoder[Team] = Encoders.product
    }
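
    Putting it together, here is a sketch (not the original code) of what the method might end up looking like. It assumes User gets a companion-object encoder just like Team above, and it swaps $"_1.actorId" for functions.col("_1.actorId") so that nothing inside the method still depends on the implicits import:

    import org.apache.spark.sql.{Dataset, Encoder, SparkSession}
    import org.apache.spark.sql.functions.col

    def createTimeLineDS[EVENT <: Event, TIMELINE <: TimeLine : Encoder]
      (convertEventToTimeLineFunction: (EVENT, Team, User) => TIMELINE)
      (sparkSession: SparkSession)
      (jobEventDS: Dataset[EVENT]): Dataset[TIMELINE] = {
        // Team and User encoders come from their companion objects; joinWith builds
        // the tuple encoders it needs from the encoders of the two datasets being joined.
        val teamDS = FuncUtils.createDSFromPostgresql[Team](sparkSession)
        val userDS = FuncUtils.createDSFromPostgresql[User](sparkSession)
        jobEventDS
          .joinWith(teamDS, jobEventDS("teamId") === teamDS("id"), "left_outer")
          .joinWith(userDS, col("_1.actorId") === userDS("id"), "left_outer")
          .map({ case ((event, team), user) => convertEventToTimeLineFunction(event, team, user) })
      }

    At the call site the concrete TIMELINE type is known, so the usual import can supply Encoder[JobEventTimeline]:

    // The caller, not createTimeLineDS, now provides the encoder for the concrete result type.
    import sparkSession.implicits._

    val jobEventTimeLine = FuncUtils.createTimeLineDS(JobEventTimeline.createJobEventTimeLine)(sparkSession)(jobEventDS)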