Tags: scala, apache-spark, inheritance, apache-spark-dataset, frameless

Spark UDF doesn't get decoded dataset class using org.typelevel.frameless encoder injection


I asked a question about how to implement a Spark Dataset with case class inheritance (see Spark Dataset with case class inheritance) and got a helpful reply suggesting org.typelevel.frameless to inject an encoder.

In the original question I had a Base trait and two case classes A and B, and wanted to create a Dataset as a sequence of Base objects. The frameless injector worked nicely, so I extended my test with a third case class C that has two Base properties:

trait Base {
  def name: String
}

case class A(name: String, number: Int) extends Base

case class B(name: String, text: String) extends Base

case class C(x: Base, y: Base)

I did this because I wanted to test a UDF that could be passed a Base column and would receive the decoded case class. To use frameless I set up some implicit values for the injection and encoders:

import frameless.{Injection, TypedExpressionEncoder}
import org.apache.spark.sql.Encoder

object dataSetImplicits {
    implicit val baseHolder: Injection[Base, Struct] = new Injection[Base, Struct] with Serializable {
      def invert(a: Struct): Base =
        a match {
          case Struct("A", name, Some(number), None) => A(name, number)
          case Struct("B", name, None, Some(text)) => B(name, text)
          case _ => throw new Exception(f"Invalid Base structure $a")
        }

      def apply(b: Base): Struct =
        b match {
          case A(name, number) => Struct("A", name, number = Some(number))
          case B(name, text) => Struct("B", name, text = Some(text))
        }
    }
    implicit val baseEnc: Encoder[Base] = TypedExpressionEncoder[Base]
    implicit val cEnc: Encoder[C] = TypedExpressionEncoder[C]
}
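
For reference, Struct is the flat carrier case class from my original question; a shape like the following satisfies the pattern matches above (the field names other than number and text are illustrative):

case class Struct(
  tag: String,                  // discriminator: "A" or "B" (illustrative name)
  name: String,
  number: Option[Int] = None,   // populated for A
  text: Option[String] = None   // populated for B
)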

I then ran my local ScalaTest:

import sparkSession.implicits._
import dataSetImplicits._
import org.apache.spark.sql.functions.udf

val a = A("Alice", 32)
val b = B("Bob", "foo")
val lb = Seq(C(a, b))

val ds: Dataset[C] = lb.toDS()

def processBase(b: Base): String = {
  b match {
    case A(name, number) => f"$name has number $number"
    case B(name, text) => f"$name has text $text"
    case _ => "Don't know"
  }
}
val processBaseUdf = udf(processBase(_))

var df = ds.withColumn("xResult", processBaseUdf(ds("x")))
df = df.withColumn("yResult", processBaseUdf(ds("y")))
df.show()

The expected result was

+--------------------+-------------------+-------------------+----------------+
|                   x|                  y|            xResult|         yResult|
+--------------------+-------------------+-------------------+----------------+
|{A, Alice, 32, NULL}|{B, Bob, NULL, foo}|Alice has number 32|Bob has text foo|
+--------------------+-------------------+-------------------+----------------+

But the test failed because the UDF is passed the Struct type and not the Base object.

If I collect the rows from the dataset and print the types of the x and y properties, they correctly show the decoded Base types:

ds.collect().foreach{x => println(f"${x.getClass.getName} has x type ${x.x.getClass.getName} and y has type ${x.y.getClass.getName}")}

returns

C has x type A and y has type B

So it seems as if the encoding/decoding is working fine except for the use of UDFs.

Can anyone advise whether this is indeed the case, or whether I'm setting up the encoders incorrectly?

Thanks,

David


Solution

  • This is fun stuff and it really pushes at the rougher edges of using frameless with plain Datasets.

    Spark's default udfs use reflection to create encoders and won't pick up an implicit encoder at all, even though the internal (private) SparkUserDefinedFunction is able to use them.
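
    You can see this with a probe udf that only reports the runtime class of its argument (a sketch; it relies on Spark falling back to its generic Row when it cannot derive an input encoder for AnyRef):

    import org.apache.spark.sql.functions.udf

    // the struct column arrives as Spark's generic Row, never as a decoded A or B
    val probe = udf((x: AnyRef) => x.getClass.getName)
    ds.withColumn("xClass", probe(ds("x"))).show(truncate = false)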

    Frameless udfs do work, but they are intended for use with TypedDataset and TypedColumn, so some workarounds are needed.
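
    For contrast, the intended frameless route goes through TypedDataset; a sketch, assuming the Injection from the question is in scope so a TypedEncoder[Base] can be derived:

    import frameless.TypedDataset
    import frameless.functions.udf
    import org.apache.spark.sql.SparkSession

    implicit val ss: SparkSession = sparkSession          // TypedDataset.create takes the session implicitly

    val tds = TypedDataset.create(lb)                     // TypedDataset[C]
    val processBaseTyped = udf[C, Base, String](processBase)
    tds.select(processBaseTyped(tds('x))).dataset.show()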

    To get the Dataset[C] you need:

    implicit val cenc = TypedExpressionEncoder[C]
    

    defined before calling toDS.
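
    That is, the encoder import has to precede the Dataset creation, exactly as in the question's test:

    import dataSetImplicits._         // brings cEnc (and the injection) into scope
    val ds: Dataset[C] = lb.toDS()    // toDS now resolves to the frameless-backed encoder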

    Then, since you can't use Spark's Scala udf directly, you need to work around it by defining your own udf wrapper on top of frameless:

    import frameless.TypedEncoder
    import org.apache.spark.sql.Column

    def udf[A: TypedEncoder, R: TypedEncoder](f: A => R): Column => Column = {
      u =>
        val scalaUdf = new frameless.functions.FramelessUdf(
          f,
          encoders = Seq(TypedEncoder[A]),
          // decode the incoming column to A before f is applied
          children = List(TypedEncoder[A].fromCatalyst(u.expr)),
          TypedEncoder[R]
        )
        new Column(scalaUdf)
    }
    

    This version of udf (which obviously handles only one parameter) takes the TypedEncoders (not the TypedExpressionEncoder used by toDS) for the input and output types, and wraps the column expression (the x or y AttributeReference) with the appropriate construction logic to rebuild Base.
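
    For example, a sketch of wiring this wrapper into the original test (it assumes the Injection and the udf wrapper above are both in scope; the local udf shadows Spark's):

    val processBaseUdf: Column => Column = udf[Base, String](processBase)

    var df = ds.withColumn("xResult", processBaseUdf(ds("x")))
    df = df.withColumn("yResult", processBaseUdf(ds("y")))
    df.show()

    which then shows: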

    +--------------------+-------------------+-------------------+----------------+
    |                   x|                  y|            xResult|         yResult|
    +--------------------+-------------------+-------------------+----------------+
    |{A, Alice, 32, null}|{B, Bob, null, foo}|Alice has number 32|Bob has text foo|
    +--------------------+-------------------+-------------------+----------------+
    

    Whilst this works fine on most runtimes, there is, aside from the usability issue, an actual runtime issue on the latest Databricks 14.3 runtime, which does not tolerate frameless' compiled-by-default UDFs (tracked by issue #803 and pull #806). This is one of the issues tested via sparkutils frameless.