Tags: scala, serialization, scala-3

Can I use mirrors to synthesize match statements for ADTs in Scala 3?


I'd like to use Mirrors or some other technology to serialize ADTs.

My concrete use case is that I'm serializing messages over a channel. I can model the messages with case classes; that's easy enough. That gives me the following code:

sealed trait Message
final case class A(a: Int) extends Message
final case class B(b: Int, s: String) extends Message

def serializeMessage(m: Message) =
  Tuple.fromProductTyped(m).toList.map(_.serialize) // doesn't compile: Message is a sum type, so there is no Mirror.ProductOf[Message]

type Primitive = Int | String

extension (p: Primitive)
  def serialize = p match {
    case i: Int    => s"an Int: $i"
    case s: String => s"an String: $s"
  }

As far as I can see, I have two problems:

  • How can I guarantee at type level that all message case classes only include fields that have a serialize method available?
  • How do I convert an m: Message to a generic tuple of "serializables" that I can act on?

I could write the match by hand, in which case the core logic is:

def serializeMessage(m: Message) = m match {
  case a: A => Tuple.fromProductTyped(a).toList.map(_.serialize)
  case b: B => Tuple.fromProductTyped(b).toList.map(_.serialize)
}

This compiles, but my API has 50+ messages, and I might also want to support use cases other than serialization, so I'd like to automate the derivation. It's purely mechanical and very repetitive, so it feels like it should be doable.


Solution

  • You can automate the pattern matching with a macro:

    // Assumes Message, A, B, Primitive and the `serialize` extension from the question are in scope
    import scala.quoted.{Expr, Quotes, quotes, Type}
    import scala.deriving.Mirror
    
    inline def serializeMessage(m: Message): List[String] = ${serializeMessageImpl('m)}
    
    def serializeMessageImpl(m: Expr[Message])(using Quotes): Expr[List[String]] = {
      import quotes.reflect.*
    
      // One CaseDef per direct child of the sealed trait (A and B here)
      val caseDefs = TypeRepr.of[Message].typeSymbol.children.map(symb => {
    
        val typeTree = TypeTree.ref(symb)
        val typeRepr = typeTree.tpe
    
        // Pattern variable `x` bound to the scrutinee, typed as the child class
        val bind = Symbol.newBind(Symbol.spliceOwner, "x", Flags.EmptyFlags, typeRepr)
        val ref = Ref(bind)
    
        typeRepr.asType match {
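          case '[a0] =>
            // Re-bind the child type as `a <: Product`, so that
            // Tuple.fromProductTyped (which requires A <: Product) typechecks
            '{tag[a0]} match {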
              case '{
                type a <: Product
                tag[`a`]
              } => {
                val mirror = Expr.summon[Mirror.ProductOf[a]].getOrElse(
                  report.errorAndAbort(s"Can't find Mirror.ProductOf[${Type.show[a]}]")
                )
    
                CaseDef(
                  Bind(bind, Typed(ref, typeTree)),
                  None,
                  '{Tuple.fromProductTyped(${ref.asExprOf[a]})(using $mirror).toList.asInstanceOf[List[Primitive]].map(_.serialize)}.asTerm
                )
              }
    
            }
        }
      })
    
      Match(m.asTerm, caseDefs).asExprOf[List[String]]
    }
    
    def tag[A] = ??? // only used inside the quoted pattern above; never called at runtime
    
    // Call site: must be in a different source file from the macro definition
    serializeMessage(B(1, "abc")) // List(an Int: 1, an String: abc)
    
    // Code generated for the call above:
    //scalac: m$proxy1 match {
    //  case x @ x =>      // case x: A =>
    //    scala.Tuple.fromProductTyped[Macro.A](x)(Macro.A.$asInstanceOf$[scala.deriving.Mirror.Product {
    //      type MirroredMonoType >: Macro.A <: Macro.A
    //      type MirroredType >: Macro.A <: Macro.A
    //      type MirroredLabel >: "A" <: "A"
    //      type MirroredElemTypes >: scala.*:[scala.Int, scala.Tuple$package.EmptyTuple] <: scala.*:[scala.Int, scala.Tuple$package.EmptyTuple]
    //      type MirroredElemLabels >: scala.*:["a", scala.Tuple$package.EmptyTuple] <: scala.*:["a", scala.Tuple$package.EmptyTuple]
    //    }]).toList.asInstanceOf[scala.List[Macro.Primitive]].map[java.lang.String](((_$1: Macro.Primitive) => Macro.serialize(_$1)))
    //  case x @ `x₂` =>   // case x: B =>
    //    scala.Tuple.fromProductTyped[Macro.B](`x₂`)(Macro.B.$asInstanceOf$[scala.deriving.Mirror.Product {
    //      type MirroredMonoType >: Macro.B <: Macro.B
    //      type MirroredType >: Macro.B <: Macro.B
    //      type MirroredLabel >: "B" <: "B"
    //      type MirroredElemTypes >: scala.*:[scala.Int, scala.*:[scala.Predef.String, scala.Tuple$package.EmptyTuple]] <: scala.*:[scala.Int, scala.*:[scala.Predef.String, scala.Tuple$package.EmptyTuple]]
    //      type MirroredElemLabels >: scala.*:["b", scala.*:["s", scala.Tuple$package.EmptyTuple]] <: scala.*:["b", scala.*:["s", scala.Tuple$package.EmptyTuple]]
    //    }]).toList.asInstanceOf[scala.List[Macro.Primitive]].map[java.lang.String](((`_$1₂`: Macro.Primitive) => Macro.serialize(`_$1₂`)))
    //}
    

    See also: "Scala 3 collection partitioning with subtypes" and
    https://github.com/lampepfl/dotty/discussions/12472
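    The macro builds its cases from the children of the sealed trait. For comparison, the compiler-synthesized Mirror.SumOf[Message] already carries that case list with no quotes or reflection: ordinal(m) says which case m is, and MirroredElemLabels names the cases. A minimal sketch, reusing the ADT from the question; messageMirror, caseLabels and caseName are hypothetical names chosen for illustration. Indexing a list of per-case handlers by ordinal, instead of looking up a name, gives a match-free dispatch.

```scala
import scala.compiletime.constValueTuple
import scala.deriving.Mirror

sealed trait Message
final case class A(a: Int) extends Message
final case class B(b: Int, s: String) extends Message

// The compiler-synthesized sum mirror; `summon` keeps its refined type,
// so MirroredElemTypes and MirroredElemLabels are known statically.
val messageMirror = summon[Mirror.SumOf[Message]]

// Case names, indexed the same way `ordinal` numbers the cases.
val caseLabels: List[String] = constValueTuple[messageMirror.MirroredElemLabels].toList

// Hypothetical helper: which case is `m`? No hand-written match needed.
def caseName(m: Message): String = caseLabels(messageMirror.ordinal(m))

@main def demo(): Unit =
  println(caseName(A(1)))        // A
  println(caseName(B(1, "abc"))) // B
```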


    Alternatively, you can introduce a type class and derive it (e.g. with Shapeless 3):

    libraryDependencies ++= Seq(
      "org.typelevel" %% "shapeless3-deriving" % "3.2.0",
      "org.typelevel" %% "shapeless3-typeable" % "3.2.0"
    )
    
    import shapeless3.deriving.K0
    import shapeless3.typeable.Typeable
    
    trait Serializer[T]:
      def serialize(t: T): String
    
    trait LowPrioritySerializer:
      given [T](using typeable: Typeable[T]): Serializer[T] with
        override def serialize(t: T): String = s"an ${typeable.describe}: $t"
    
    object Serializer extends LowPrioritySerializer:
      given prod[T](using inst: K0.ProductInstances[Serializer, T]): Serializer[T] with
        override def serialize(t: T): String = inst.foldRight[String](t)("")(
          [a] => (s: Serializer[a], x: a, acc: String) =>
            s.serialize(x) + (if acc.isEmpty then "" else ", ") + acc
        )
    
      given coprod[T](using inst: K0.CoproductInstances[Serializer, T]): Serializer[T] with
        override def serialize(t: T): String = inst.fold[String](t)(
          [a <: T] => (s: Serializer[a], x: a) => s.serialize(x)
        )
    
    extension [T: Serializer](t: T)
      def serialize = summon[Serializer[T]].serialize(t)
    
    A(1).serialize // an Int: 1
    B(1, "abc").serialize // an Int: 1, an String: abc
    (A(1): Message).serialize // an Int: 1
    (B(1, "abc"): Message).serialize // an Int: 1, an String: abc
    

    Under the hood, Shapeless 3 itself is built on scala.deriving.Mirror.
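    The same Serializer can also be derived with scala.deriving.Mirror alone, following the pattern of the type class derivation chapter in the Scala 3 reference: summonInline over MirroredElemTypes collects the element instances, and ordinal picks the instance for a sum. This is a minimal sketch, not a drop-in replacement for the Shapeless version: it assumes only Int and String fields, each case class opts in with `derives Serializer`, and the helpers summonAll, sumSerializer and productSerializer are hypothetical names. A field type without a Serializer fails at compile time inside summonInline, which gives the type-level guarantee asked about in the question.

```scala
import scala.compiletime.{erasedValue, summonInline}
import scala.deriving.Mirror

trait Serializer[T]:
  def serialize(t: T): String

object Serializer:
  // Assumption: only Int and String leaf instances are needed.
  given Serializer[Int] with
    def serialize(t: Int): String = s"an Int: $t"

  given Serializer[String] with
    def serialize(t: String): String = s"an String: $t"

  // One Serializer per element type (fields of a product, or cases of a sum),
  // in the same order as the mirror's MirroredElemTypes.
  inline def summonAll[Ts <: Tuple]: List[Serializer[?]] =
    inline erasedValue[Ts] match
      case _: EmptyTuple => Nil
      case _: (t *: ts)  => summonInline[Serializer[t]] :: summonAll[ts]

  // Entry point used by `derives Serializer`.
  inline def derived[T](using m: Mirror.Of[T]): Serializer[T] =
    lazy val elems = summonAll[m.MirroredElemTypes]
    inline m match
      case s: Mirror.SumOf[T]     => sumSerializer(s, elems)
      case _: Mirror.ProductOf[T] => productSerializer(elems)

  // Sum: `ordinal` selects the case's instance, replacing a hand-written match.
  def sumSerializer[T](s: Mirror.SumOf[T], elems: => List[Serializer[?]]): Serializer[T] =
    new Serializer[T]:
      def serialize(t: T): String =
        elems(s.ordinal(t)).asInstanceOf[Serializer[T]].serialize(t)

  // Product: serialize each field with its own instance, then join.
  def productSerializer[T](elems: => List[Serializer[?]]): Serializer[T] =
    new Serializer[T]:
      def serialize(t: T): String =
        t.asInstanceOf[Product].productIterator
          .zip(elems)
          .map((field, s) => s.asInstanceOf[Serializer[Any]].serialize(field))
          .mkString(", ")

sealed trait Message derives Serializer
final case class A(a: Int) extends Message derives Serializer
final case class B(b: Int, s: String) extends Message derives Serializer

@main def demo(): Unit =
  val ser = summon[Serializer[Message]]
  println(ser.serialize(A(1)))        // an Int: 1
  println(ser.serialize(B(1, "abc"))) // an Int: 1, an String: abc
```

    The casts in sumSerializer and productSerializer are safe only because the instances are listed in the mirror's element order; keeping them behind `derived` is what preserves that invariant.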

    See also: "How to access parameter list of case class in a dotty macro" and "Using K0.ProductInstances in shapeless3"