Search code examples
Tags: scala, reflection, metaprogramming, scala-macros, dotty

How to access parameter list of case class in a dotty macro


I am trying to learn metaprogramming in Dotty, specifically compile-time code generation. I thought that building something would be a good way to learn, so I decided to write a CSV parser that parses lines into case classes, and I want to use Dotty macros to generate the decoders:

/** Decodes a single raw string into a value of type `T`, or reports why it could not.
  * NOTE(review): `ParseError` is not declared in this snippet; presumably it is
  * `case class ParseError(str: String, msg: String)` as in the answers.
  */
trait Decoder[T]{
  def decode(str:String):Either[ParseError, T]
}

/** Companion holding the primitive decoders and the (still unimplemented) macro entry point.
  *
  * NOTE(review): `given x as T` and `QuoteContext` are pre-3.0 Dotty syntax/APIs; Scala 3.0+
  * spells them `given x: T = ...` and `(using Quotes)`, so this snippet does not compile there.
  */
object Decoder {
  // Any string is accepted unchanged; this decoder never returns Left.
  inline given stringDec as Decoder[String] = new Decoder[String] {
    override def decode(str: String): Either[ParseError, String] = Right(str)
  }

  // Returns Left(ParseError) whenever `toIntOption` cannot parse the string.
  inline given intDec as Decoder[Int] = new Decoder[Int] {
    override def decode(str: String): Either[ParseError, Int] =
      str.toIntOption.toRight(ParseError(str, "value is not valid Int"))
  }
  
  // Macro entry point: the splice `${make[T]}` is expanded at compile time at each call site.
  inline def forType[T]:Decoder[T] = ${make[T]}

  // The part the question asks about: produce an Expr[Decoder[T]] by inspecting case class T.
  def make[T:Type](using qctx: QuoteContext):Expr[Decoder[T]] = ???
}

I have provided basic decoders for Int and String; now I am looking for guidance on the `def make[T: Type]` method. How do I iterate over the parameter list of a case class `T` inside this method? Are there any recommended ways or patterns for doing this?


Solution

  • Using standard type class derivation in Dotty

    import scala.deriving.Mirror
    
    /** A parse failure: the offending input and a human-readable reason. */
    case class ParseError(str: String, msg: String)
    
    /** Decodes one raw string into a `T`, or returns the reason it could not. */
    trait Decoder[T]{
      def decode(str:String): Either[ParseError, T]
    }
    
    object Decoder {
      given Decoder[String] with {
        override def decode(str: String): Either[ParseError, String] = Right(str)
      }
    
      given Decoder[Int] with {
        override def decode(str: String): Either[ParseError, Int] =
          str.toIntOption.toRight(ParseError(str, "value is not valid Int"))
      }
    
      /** Derives a `Decoder` for a product type (case class) from its `Mirror`.
        *
        * Sum types (sealed traits / enums) are rejected with a compile-time error
        * instead of producing an instance that throws at run time.
        */
      inline def derived[T](using m: Mirror.Of[T]): Decoder[T] =
        inline m match {
          // Field decoders are summoned only in this branch, so a sum type reaches the
          // explicit error below rather than a "no given instance" error for its cases.
          case p: Mirror.ProductOf[T] => productDecoder(p, summonAll[m.MirroredElemTypes])
          case _: Mirror.SumOf[T] =>
            compiletime.error("Decoder derivation is only supported for case classes (product types)")
        }
    
      /** Summons a `Decoder` for every element type of the tuple `T`, in order. */
      inline def summonAll[T <: Tuple]: List[Decoder[?]] =
        compiletime.summonAll[Tuple.Map[T, Decoder]].toList.asInstanceOf[List[Decoder[?]]]
    
      /** Builds a decoder that splits the input on ',' and decodes piece `i` with `elems(i)`,
        * then assembles `T` through the mirror.
        *
        * @param p     product mirror used to construct `T` from the decoded values
        * @param elems one decoder per field of `T`, in declaration order
        * @return a decoder yielding `Left` when the input has no pieces, when the number of
        *         pieces differs from the number of fields, or when any field fails to decode
        */
      def productDecoder[T](p: Mirror.ProductOf[T], elems: List[Decoder[?]]): Decoder[T] =
        new Decoder[T] {
          def decode(str: String): Either[ParseError, T] = {
            // `split` drops trailing empty strings, so e.g. ",," yields no pieces at all.
            val strs = str.split(',')
            if (strs.isEmpty) Left(ParseError(str, "nothing to split"))
            // Guard the arity: `zip` truncates silently, so missing fields would make
            // `fromProduct` throw and extra fields would be ignored.
            else if (strs.length != elems.length)
              Left(ParseError(str, s"expected ${elems.length} fields but got ${strs.length}"))
            else elems.zip(strs)
              .traverse(_.decode(_))
              .map(ts => p.fromProduct(Tuple.fromArray(ts.toArray)))
          }
        }
    
      /** Applies `f` to every element; returns the leftmost `Left` if any element fails,
        * otherwise `Right` of all results in the original order.
        */
      extension [E,A,B](es: List[A])
        def traverse(f: A => Either[E, B]): Either[E, List[B]] =
          es.foldRight[Either[E, List[B]]](Right(Nil))((h, tRes) => map2(f(h), tRes)(_ :: _))
    
      /** Combines two `Either`s with `f`; a `Left` in `a` wins over one in `b`. */
      def map2[E, A, B, C](a: Either[E, A], b: Either[E, B])(f: (A, B) => C): Either[E, C] =
        for { a1 <- a; b1 <- b } yield f(a1,b1)
    }
    
    case class A(i: Int, s: String) derives Decoder
    
    // Decode a few sample lines; expected output, in order:
    //   Right(A(10,abc))
    //   Left(ParseError(xxx,value is not valid Int))
    //   Left(ParseError(,,,nothing to split))
    Seq("10,abc", "xxx,abc", ",,").foreach(line => println(summon[Decoder[A]].decode(line)))
    

    Tested in 3.2.0.


    Using Shapeless-3

    import shapeless3.deriving.K0
    import shapeless3.typeable.Typeable
    
    /** A parse failure: the offending input and a human-readable reason. */
    case class ParseError(str: String, msg: String)
    
    /** Decodes one raw string into a `T`, or returns the reason it could not. */
    trait Decoder[T]{
      def decode(str:String): Either[ParseError, T]
    }
    
    object Decoder {
      inline given stringDec: Decoder[String] = new Decoder[String] {
        override def decode(str: String): Either[ParseError, String] = Right(str)
      }
    
      inline given intDec: Decoder[Int] = new Decoder[Int] {
        override def decode(str: String): Either[ParseError, Int] =
          str.toIntOption.toRight(ParseError(str, "value is not valid Int"))
      }
    
      /** Entry point used by `derives Decoder`.
        *
        * NOTE(review): only the product branch is supplied; the coproduct branch is `null`,
        * so sealed traits / enums are not supported. Deriving one presumably fails when that
        * null is used — TODO replace with a real coproduct decoder or a compile-time error.
        */
      inline def derived[A](using gen: K0.Generic[A]): Decoder[A] =
        gen.derive(productDecoder, null)
    
      /** Decodes a product `T` from comma-separated fields, consuming one piece per field
        * from left to right.
        *
        * Returns `Left` with the field's own error when a field fails, `Left("value is not
        * valid <T>")` when the input runs out before all fields are decoded, and `Left` when
        * input remains after the last field instead of silently dropping it.
        */
      given productDecoder[T](using inst: K0.ProductInstances[Decoder, T], typeable: Typeable[T]): Decoder[T] = new Decoder[T] {
        def decode(str: String): Either[ParseError, T] = {
          // Accumulator: the pieces still to consume, plus the first decoding error seen.
          type Acc = (List[String], Option[ParseError])
          inst.unfold[Acc](str.split(',').toList, None)([t] => (acc: Acc, dec: Decoder[t]) =>
            acc._1 match {
              case head :: tail => dec.decode(head) match {
                case Right(t) => ((tail, None), Some(t))
                case Left(e)  => ((Nil, Some(e)), None)
              }
              case Nil => (acc, None)
            }
          ) match {
            case ((_, Some(e)), None)  => Left(e)
            case ((_, None), None)     => Left(ParseError(str, s"value is not valid ${typeable.describe}"))
            case ((Nil, _), Some(t))   => Right(t)
            // Every field decoded but pieces remain: reject instead of ignoring them.
            case ((extra, _), Some(_)) => Left(ParseError(str, s"unexpected extra fields: ${extra.mkString(",")}"))
          }
        }
      }
    }
    
    case class A(i: Int, s: String) derives Decoder
    
    /** Runs the derived decoder over sample inputs. Expected output, in order:
      *   Right(A(10,abc))
      *   Left(ParseError(xxx,value is not valid Int))
      *   Left(ParseError(,,value is not valid A))
      */
    @main def test = {
      val decoder = summon[Decoder[A]]
      for (input <- List("10,abc", "xxx,abc", ","))
        println(decoder.decode(input))
    }
    

    build.sbt

    scalaVersion := "3.2.0"
    
    libraryDependencies ++= Seq(
      "org.typelevel" %% "shapeless3-deriving" % "3.2.0",
      "org.typelevel" %% "shapeless3-typeable" % "3.2.0"
    )
    

    Using Dotty macros + TASTy reflection, as in dotty-macro-examples/macroTypeclassDerivation (this approach is even lower-level than the one with scala.deriving.Mirror)

    import scala.quoted.*
    
    /** A parse failure: the offending input and a human-readable reason. */
    case class ParseError(str: String, msg: String)
    
    /** Decodes one raw string into a `T`, or returns the reason it could not. */
    trait Decoder[T]{
      def decode(str: String): Either[ParseError, T]
    }
    
    object Decoder {
      inline given Decoder[String] with {
        override def decode(str: String): Either[ParseError, String] = Right(str)
      }
    
      inline given Decoder[Int] with {
        override def decode(str: String): Either[ParseError, Int] =
          str.toIntOption.toRight(ParseError(str, "value is not valid Int"))
      }
    
      /** Entry point used by `derives Decoder`; expands `derivedImpl` at compile time. */
      inline def derived[T]: Decoder[T] = ${ derivedImpl[T] }
    
      /** Dispatches on the flags of `T`'s symbol. Case classes get a product decoder;
        * anything else is reported as a compile error at the derivation site rather than
        * crashing macro expansion with an exception.
        */
      def derivedImpl[T](using Quotes, Type[T]): Expr[Decoder[T]] = {
        import quotes.reflect.*
        val tpeSym = TypeRepr.of[T].typeSymbol
        if (tpeSym.flags.is(Flags.Case)) productDecoder[T]
        else if (tpeSym.flags.is(Flags.Trait & Flags.Sealed))
          report.errorAndAbort(s"Decoder derivation for sealed trait ${tpeSym.fullName} is not implemented")
        else report.errorAndAbort(s"Unsupported combination of flags: ${tpeSym.flags.show}")
      }
    
      /** Builds an `Expr[Decoder[T]]` for case class `T`. It looks up one field decoder per
        * case field; the generated `decode` splits the input on ',', checks the piece count
        * against the field count, decodes each piece, and calls `T`'s constructor.
        */
      def productDecoder[T](using Quotes, Type[T]): Expr[Decoder[T]] = {
        import quotes.reflect.*
        val fields: List[Symbol]             = TypeRepr.of[T].typeSymbol.caseFields
        val fieldTypeTrees: List[TypeTree]   = fields.map(_.tree.asInstanceOf[ValDef].tpt)
        val decoderTerms: List[Term]         = fieldTypeTrees.map(lookupDecoderFor(_))
        val decoders: Expr[List[Decoder[_]]] = Expr.ofList(decoderTerms.map(_.asExprOf[Decoder[_]]))
        // Field count lifted into the generated code for the run-time arity check.
        val arity: Expr[Int]                 = Expr(fields.length)
    
        // Builds `new T(fields(0).asInstanceOf[F0], ..., fields(n-1).asInstanceOf[Fn-1])`.
        def mkT(fields: Expr[List[_]]): Expr[T] = {
          Apply(
            Select.unique(New(TypeTree.of[T]), "<init>"),
            fieldTypeTrees.zipWithIndex.map((fieldType, i) =>
              TypeApply(
                Select.unique(
                  Apply(
                    Select.unique(
                      fields.asTerm,
                      "apply"),
                    List(Literal(IntConstant(i)))
                  ), "asInstanceOf"),
                List(fieldType)
              )
            )
          ).asExprOf[T]
        }
    
        '{
          new Decoder[T]{
            override def decode(str: String): Either[ParseError, T] = {
              val strs = str.split(',').toList
              val expected = $arity
              if (strs.isEmpty) Left(ParseError(str, "nothing to split"))
              // Guard the arity: `zip` truncates silently, so missing fields would make the
              // generated `fields(i)` calls throw and extra fields would be ignored.
              else if (strs.length != expected)
                Left(ParseError(str, s"expected $expected fields but got ${strs.length}"))
              else $decoders.zip(strs).traverse(_.decode(_)).map(fields =>
                ${mkT('fields)}
              )
            }
          }
        }
      }
    
      /** Summons `Decoder[t]` for the field type tree `t` via implicit search. When no
        * instance exists, it aborts expansion with a readable error instead of a `MatchError`.
        */
      def lookupDecoderFor(using Quotes)(t: quotes.reflect.Tree): quotes.reflect.Term = {
        import quotes.reflect.*
        val tpe: TypeTree = Applied(TypeTree.of[Decoder], List(t))
        Implicits.search(tpe.tpe) match {
          case res: ImplicitSearchSuccess => res.tree
          case failure: ImplicitSearchFailure =>
            report.errorAndAbort(s"No Decoder found for field type ${t.show}: ${failure.explanation}")
        }
      }
    
      /** Applies `f` to every element; returns the leftmost `Left` if any element fails,
        * otherwise `Right` of all results in the original order.
        */
      extension [E,A,B](es: List[A]) {
        def traverse(f: A => Either[E, B]): Either[E, List[B]] =
          es.foldRight[Either[E, List[B]]](Right(Nil))((h, tRes) => map2(f(h), tRes)(_:: _))
      }
    
      /** Combines two `Either`s with `f`; a `Left` in `a` wins over one in `b`. */
      def map2[E, A, B, C](a: Either[E, A], b: Either[E, B])(f: (A, B) => C): Either[E, C] =
        for { a1 <- a; b1 <- b } yield f(a1,b1)
    }
    
    case class A(i: Int, s: String) derives Decoder
    
    /** Runs the macro-derived decoder over sample inputs. Expected output, in order:
      *   Right(A(10,abc))
      *   Left(ParseError(xxx,value is not valid Int))
      *   Left(ParseError(,,nothing to split))
      */
    @main def test = {
      val decoder = summon[Decoder[A]]
      List("10,abc", "xxx,abc", ",").foreach { input =>
        println(decoder.decode(input))
      }
    }
    

    Tested in 3.2.0.


    We can also implement Generic ourselves, as in Scala 2 / Shapeless 2

    (see also: "Scala 3 collection partitioning with subtypes")

    import scala.deriving.Mirror
    
    /** Two-way conversion between a type `T` and its generic representation `Repr`.
      * The given instance below sets `Repr` to the tuple of a case class's field types.
      */
    trait Generic[T] {
      type Repr
      def to(t: T): Repr
      def from(r: Repr): T
    }
    
    object Generic {
      /** `Generic[T]` with its `Repr` member pinned to `Repr0`. */
      type Aux[T, Repr0] = Generic[T] { type Repr = Repr0 }
    
      /** Wraps a pair of conversion functions into an `Aux[T, Repr0]`.
        *
        * @param f converts a `T` into its representation
        * @param g rebuilds a `T` from its representation
        */
      def instance[T, Repr0](f: T => Repr0, g: Repr0 => T): Aux[T, Repr0] =
        new Generic[T] {
          override type Repr = Repr0
          override def to(t: T): Repr0   = f(t)
          override def from(r: Repr0): T = g(r)
        }
    
      /** Extension syntax: `value.toRepr` and `repr.to[A]`. */
      object ops {
        extension [A](a: A)
          def toRepr(using g: Generic[A]): g.Repr = g.to(a)
    
        extension [Repr](a: Repr)
          def to[A](using g: Generic.Aux[A, Repr]): A = g.from(a)
      }
    
      /** A `Generic` for any product `T` whose element tuple itself has a product mirror.
        * `to` rebuilds the fields as that tuple; `from` rebuilds `T` from the tuple.
        */
      given [T <: Product](using
        m: Mirror.ProductOf[T],
        m1: Mirror.ProductOf[m.MirroredElemTypes]
      ): Aux[T, m.MirroredElemTypes] = {
        val toTuple: T => m.MirroredElemTypes   = value => m1.fromProduct(value)
        val fromTuple: m.MirroredElemTypes => T = tuple => m.fromProduct(tuple)
        instance[T, m.MirroredElemTypes](toTuple, fromTuple)
      }
    }
    

    and then derive the type class with Generic:

    /** A parse failure: the offending input and a human-readable reason. */
    case class ParseError(str: String, msg: String)
    
    /** Decodes one raw string into a `T`, or returns the reason it could not. */
    trait Decoder[T]{
      def decode(str:String): Either[ParseError, T]
    }
    
    /** Instances are assembled structurally: primitives, tuples (`head *: tail`), and any
      * type that has a `Generic` whose `Repr` can itself be decoded.
      */
    object Decoder {
      given Decoder[String] with {
        override def decode(str: String): Either[ParseError, String] = Right(str)
      }
    
      given Decoder[Int] with {
        override def decode(str: String): Either[ParseError, Int] =
          str.toIntOption.toRight(ParseError(str, "value is not valid Int"))
      }
    
      // End of the tuple: succeeds only if nothing is left to consume.
      given Decoder[EmptyTuple] with {
        override def decode(str: String): Either[ParseError, EmptyTuple] =
          if (str.isEmpty) Right(EmptyTuple)
          else Left(ParseError(str, "not empty string"))
      }
    
      // Head is the text before the first ','; the tail decoder gets the text after that ','.
      given [H, T <: Tuple](using hDecoder: Decoder[H], tDecoder: Decoder[T]): Decoder[H *: T] with {
        override def decode(str: String): Either[ParseError, H *: T] = {
          val (headText, remainder) = str.span(_ != ',')
          hDecoder.decode(headText).flatMap { head =>
            tDecoder.decode(remainder.stripPrefix(",")).map(tail => head *: tail)
          }
        }
      }
    
      // Decodes the generic representation, then converts it back to `T`.
      given [T](using gen: Generic[T], decoder: Decoder[gen.Repr]): Decoder[T] with {
        override def decode(str: String): Either[ParseError, T] =
          for (repr <- decoder.decode(str)) yield gen.from(repr)
      }
    }
    
    case class A(i: Int, s: String)
    
    // Decode a few sample lines through the Generic-based instance; expected output, in order:
    //   Right(A(10,abc))
    //   Left(ParseError(xxx,value is not valid Int))
    //   Left(ParseError(xxx,not empty string))
    //   Left(ParseError(,value is not valid Int))
    List("10,abc", "xxx,abc", "10,abc,xxx", ",,").foreach(line => println(summon[Decoder[A]].decode(line)))
    

    Tested in 3.2.0.


    For comparison, deriving type classes in Scala 2 is discussed in:

    "Use the lowest subtype in a typeclass?"