
Scala DataSet with case class inheritance


I'd like to store different but related types in a Spark DataFrame while still working with strongly typed case classes through a Dataset. For example, say I have a Base trait and two case classes, A and B, that extend it:

trait Base {
  def name: String
}

case class A(name: String, number: Int) extends Base

case class B(name: String, text: String) extends Base

I'd like to create val lb = List[Base](A("Alice", 20), B("Bob", "Foo")) and then turn it into a Dataset with lb.toDS(). Not surprisingly, this doesn't work: Spark has no encoder for the trait that covers its different subclasses.

I could manually create a case class representing a structure that can hold information for both A and B:

case class Struct(typ: String, name: String, number: Option[Int] = None, text: Option[String] = None)

And I could add some functions to convert an instance of the Base trait to a Struct and vice versa:

trait Base {
  def name: String

  def asStruct: Struct = {
    this match {
      case A(name, number) => Struct("A", name, number = Some(number))
      case B(name, text) => Struct("B", name, text = Some(text))
    }
  }
}

case class Struct(typ: String, name: String, number: Option[Int] = None, text: Option[String] = None) {
  def asBase: Base = {
    this match {
      case Struct("A", name, Some(number), None) => A(name, number)
      case Struct("B", name, None, Some(text)) => B(name, text)
      case _ => throw new IllegalArgumentException(s"Invalid Base structure: $this")
    }
  }
}
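Before bringing Spark in, the two conversions can be sanity-checked in plain Scala. This is a sketch with one deliberate change from the code above: Base is marked sealed, so the compiler can warn when a new subtype is added without a matching case in asStruct. The error message uses the s-interpolator with $this, since {s} in an f-string refers to no defined name. RoundTrip, samples, roundTripped and rejectsBadRow are illustrative names, not part of any API.

```scala
// Change from the question: `sealed`, so asStruct's match is checked for exhaustiveness.
sealed trait Base {
  def name: String

  // Base -> Struct: flatten into the shared row shape, tagged with the type name.
  def asStruct: Struct = this match {
    case A(name, number) => Struct("A", name, number = Some(number))
    case B(name, text)   => Struct("B", name, text = Some(text))
  }
}

case class A(name: String, number: Int) extends Base
case class B(name: String, text: String) extends Base

case class Struct(typ: String, name: String, number: Option[Int] = None, text: Option[String] = None) {
  // Struct -> Base: rebuild the subtype named by `typ`; reject inconsistent rows.
  def asBase: Base = this match {
    case Struct("A", name, Some(number), None) => A(name, number)
    case Struct("B", name, None, Some(text))   => B(name, text)
    case _ => throw new IllegalArgumentException(s"Invalid Base structure: $this")
  }
}

// Illustrative checks (hypothetical names).
object RoundTrip {
  val samples: List[Base] = List(A("Alice", 32), B("Bob", "foo"))

  // Base -> Struct -> Base should give back the original values.
  val roundTripped: List[Base] = samples.map(_.asStruct.asBase)

  // A row tagged "A" but carrying `text` instead of `number` is rejected.
  val rejectsBadRow: Boolean =
    try { Struct("A", "Eve", text = Some("x")).asBase; false }
    catch { case _: IllegalArgumentException => true }

  def main(args: Array[String]): Unit = {
    println(roundTripped)
    println(rejectsBadRow)
  }
}
```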

Then I can create a Dataset as follows:

    val a = A("Alice", 32)
    val b = B("Bob", "foo")

    val ls = List[Struct](a.asStruct, b.asStruct)

    val sparkSession = spark // the SparkSession predefined in spark-shell
    import sparkSession.implicits._

    val df = ls.toDS()

    df.show()

+---+-----+------+----+
|typ| name|number|text|
+---+-----+------+----+
|  A|Alice|    32|NULL|
|  B|  Bob|  NULL| foo|
+---+-----+------+----+
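The asBase direction is not exercised above. A minimal sketch of the full manual round trip, assuming a local SparkSession (in spark-shell you would reuse the predefined spark instead); ManualRoundTrip and its members are illustrative names. Because there is still no Encoder[Base], ds.map(_.asBase) would not compile, so the sketch converts back after collect() on the driver. That is only reasonable for data that fits in driver memory.

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

// Model from the question; Base stays an unsealed trait as in the original.
trait Base {
  def name: String

  // Base -> Struct: flatten into the shared row shape, tagged with the type name.
  def asStruct: Struct = this match {
    case A(name, number) => Struct("A", name, number = Some(number))
    case B(name, text)   => Struct("B", name, text = Some(text))
  }
}

case class A(name: String, number: Int) extends Base
case class B(name: String, text: String) extends Base

case class Struct(typ: String, name: String, number: Option[Int] = None, text: Option[String] = None) {
  // Struct -> Base: rebuild the subtype named by `typ`.
  def asBase: Base = this match {
    case Struct("A", name, Some(number), None) => A(name, number)
    case Struct("B", name, None, Some(text))   => B(name, text)
    case _ => throw new IllegalArgumentException(s"Invalid Base structure: $this")
  }
}

// Illustrative driver object (hypothetical name).
object ManualRoundTrip {
  // Local-mode session for experimenting.
  val spark: SparkSession = SparkSession.builder()
    .master("local[*]")
    .appName("base-struct-sketch")
    .getOrCreate()
  import spark.implicits._

  val original: List[Base] = List(A("Alice", 32), B("Bob", "foo"))

  // Write side: Struct is a plain case class, so Spark can encode it.
  val ds: Dataset[Struct] = original.map(_.asStruct).toDS()

  // Read side: no Encoder[Base] exists, so convert on the driver after collecting.
  val restored: List[Base] = ds.collect().toList.map(_.asBase)

  def main(args: Array[String]): Unit = {
    restored.foreach(b => println(s"${b.getClass.getSimpleName}: $b"))
    spark.stop()
  }
}
```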

This approach works, but is it possible to write an encoder that automatically treats a Base instance as a Struct, using the asStruct method above?


Solution

  • Use the frameless library's injections:

    import frameless._
    import org.apache.spark.sql.{Dataset, Encoder}

    // the injection: Base is the type we want an encoder for, Struct is how it is stored
    implicit val baseHolder: Injection[Base, Struct] = new Injection[Base, Struct] with Serializable {
      // Struct -> Base, used when reading rows back
      def invert(a: Struct): Base =
        a match {
          case Struct("A", name, Some(number), None) => A(name, number)
          case Struct("B", name, None, Some(text)) => B(name, text)
          case _ => throw new IllegalArgumentException(s"Invalid Base structure: $a")
        }

      // Base -> Struct, used when writing rows
      def apply(b: Base): Struct =
        b match {
          case A(name, number) => Struct("A", name, number = Some(number))
          case B(name, text) => Struct("B", name, text = Some(text))
        }
    }

    // frameless derives a TypedEncoder[Base] from the injection; expose it as a Spark Encoder
    implicit val enc: Encoder[Base] = TypedExpressionEncoder[Base]

    import sparkSession.implicits._
    val lb: Dataset[Base] = Seq[Base](A("Alice", 20), B("Bob", "Foo")).toDS()
    lb.show()
    

    yields:

    +---+-----+------+----+
    |typ| name|number|text|
    +---+-----+------+----+
    |  A|Alice|    20|null|
    |  B|  Bob|  null| Foo|
    +---+-----+------+----+
    

    But are the collected values really instances of A and B?

    lb.collect().foreach{ b =>
      println(b.getClass.getName)
    }
    

    prints the class names of A and B, so the original subtypes are restored.

    The order of the type parameters of Injection matters: the first is the type we create the encoder for (Base), and the second is the type it is stored as in the Dataset (Struct).
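The injection above repeats the matching logic the question already wrote in asStruct and asBase. A sketch that reuses those methods instead, assuming frameless's Injection companion provides apply(f, g) for building an injection from two functions (as in the frameless versions I know). BaseInjection, samples and lawHolds are illustrative names. frameless expects invert(apply(x)) == x to hold, and the sketch checks that on two samples. With baseToStruct in implicit scope, TypedExpressionEncoder[Base] should resolve the same way as in the answer.

```scala
import frameless.Injection

// Model from the question (assumed unchanged).
trait Base {
  def name: String

  def asStruct: Struct = this match {
    case A(name, number) => Struct("A", name, number = Some(number))
    case B(name, text)   => Struct("B", name, text = Some(text))
  }
}

case class A(name: String, number: Int) extends Base
case class B(name: String, text: String) extends Base

case class Struct(typ: String, name: String, number: Option[Int] = None, text: Option[String] = None) {
  def asBase: Base = this match {
    case Struct("A", name, Some(number), None) => A(name, number)
    case Struct("B", name, None, Some(text))   => B(name, text)
    case _ => throw new IllegalArgumentException(s"Invalid Base structure: $this")
  }
}

// Hypothetical holder object for the implicit.
object BaseInjection {
  // First type parameter: the type we want an encoder for (Base).
  // Second type parameter: the type it is stored as in the Dataset (Struct).
  // The conversion logic lives only in asStruct/asBase.
  implicit val baseToStruct: Injection[Base, Struct] =
    Injection[Base, Struct](_.asStruct, _.asBase)

  // Check invert(apply(x)) == x on a couple of sample values.
  val samples: List[Base] = List(A("Alice", 20), B("Bob", "Foo"))
  val lawHolds: Boolean =
    samples.forall(b => baseToStruct.invert(baseToStruct(b)) == b)
}
```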