Scala 3 constructor inheritance with macros

Every class implementing the trait must declare a constructor that sets the trait's fields:

/** Common interface of the payload-carrying case classes below.
  *
  * `payload1` and `payload2` are abstract, so each implementing class has to supply
  * them — here as `override val` constructor parameters.
  */
sealed trait WithPayload:
    /** Per-class label; the classes below override it with a constant string. */
    def description: String
    def payload1: Int
    def payload2: Long

// Every WithPayload field has to be listed as a constructor parameter.
final case class Foo(
    override val payload1: Int,
    override val payload2: Long
) extends WithPayload:
    override def description = "foo"

// ...and listed again, identically, for every further implementation.
final case class Bar(
    override val payload1: Int,
    override val payload2: Long
) extends WithPayload:
    override def description = "bar"

Is there a way to get rid of the repeated constructor declarations with a macro? Something like a C-style `#define EXTENDS_WITH_PAYLOAD ( \`, continuing with:

    override val payload1: Int, \
    override val payload2: Long \
) extends WithPayload

and then:

// Desired usage: EXTENDS_WITH_PAYLOAD would expand to the shared
// `(override val payload1: Int, override val payload2: Long) extends WithPayload`,
// so each class only states what differs (here, `description`).
final case class Foo EXTENDS_WITH_PAYLOAD:
    override def description = "foo"

final case class Bar EXTENDS_WITH_PAYLOAD:
    override def description = "bar"


  • import scala.annotation.{StaticAnnotation, compileTimeOnly}
    import scala.language.experimental.macros
    import scala.reflect.macros.blackbox
    // Scala 2 macro annotation. If macro annotations are not enabled, the annotation is not
    // expanded and @compileTimeOnly reports "enable macro annotations" (presumably the
    // -Ymacro-annotations flag on 2.13 or the paradise plugin on 2.12 is needed — confirm).
    @compileTimeOnly("enable macro annotations")
    class extendsWithPayload extends StaticAnnotation {
      def macroTransform(annottees: Any*): Any = macro ExtendsWithPayloadMacros.macroTransformImpl
    // NOTE(review): the closing brace of the class above appears to be missing from this extract.
    // Expansion: appends WithPayload to the parents of the annotated class and adds
    // `override val payload1: Int` / `override val payload2: Long` to its first parameter
    // list, creating a parameter list if the class has none.
    object ExtendsWithPayloadMacros {
      def macroTransformImpl(c: blackbox.Context)(annottees: c.Tree*): c.Tree = {
        import c.universe._
        annottees match {
          case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" :: tail =>
            val parents1 = parents :+ tq"WithPayload"
            val newParams = Seq(q"override val payload1: Int", q"override val payload2: Long")
            val paramss1 = paramss match {
              case Nil => Seq(newParams)
              case params :: paramss1 => (params ++ newParams) :: paramss1
              // NOTE(review): the closing brace of this match and the opening of the result
              // quasiquote appear to be missing from this extract; the next line is the first
              // line of that quasiquote (the class rebuilt with paramss1 and parents1).
              $mods class $tpname[..$tparams] $ctorMods(...$paramss1) extends { ..$earlydefns } with ..$parents1 { $self =>
    // Usage. NOTE(review): the @extendsWithPayload annotations on Foo and Bar appear to be
    // missing from this extract; the printed expansion below shows both classes rewritten.
    sealed trait WithPayload {
      def description: String
      def payload1: Int
      def payload2: Long
    final case class Foo() {
      override def description = "foo"
    final case class Bar() {
      override def description = "bar"
    // Printed expansion: the payload fields and the WithPayload parent were added by the macro.
    //final case class Foo extends WithPayload with scala.Product with scala.Serializable {
    //    override <caseaccessor> <paramaccessor> val payload1: Int = _;
    //    override <caseaccessor> <paramaccessor> val payload2: Long = _;
    //    def <init>(payload1: Int, payload2: Long) = {
    //      super.<init>();
    //      ()
    //    };
    //    override def description = "foo"
    //  };
    //  ()
    //final case class Bar extends WithPayload with scala.Product with scala.Serializable {
    //    override <caseaccessor> <paramaccessor> val payload1: Int = _;
    //    override <caseaccessor> <paramaccessor> val payload2: Long = _;
    //    def <init>(payload1: Int, payload2: Long) = {
    //      super.<init>();
    //      ()
    //    };
    //    override def description = "bar"
    //  };
    //  ()
    • Scala 3 is getting macro annotations too, but definitions they add (such as a new parent) are visible only inside the macro expansion, not to the rest of the program:

    // build.sbt. MacroAnnotation is an experimental API, hence @experimental on the annotated classes below.
    scalaVersion := "3.3.0-RC4"
    import scala.annotation.{MacroAnnotation, experimental}
    import scala.collection.mutable
    import scala.quoted.*
    /*sealed*/ trait WithPayload:
      def description: String
      def payload1: Int
      def payload2: Long
    // Scala 3 macro annotation: re-emits the annotated class with WithPayload appended to its
    // parents. Constructor, self type and body are copied unchanged, so this attempt does not
    // add any payload fields.
    class extendsWithPayload extends MacroAnnotation:
      def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] =
        import quotes.reflect.*
        tree match
          case ClassDef(className, ctr, parents, self, body) =>
            val res = List(ClassDef.copy(tree)(className, ctr, parents :+ TypeTree.of[WithPayload], self, body))
            // NOTE(review): the lines that return (and presumably print) `res`, plus any fallback
            // case for non-class definitions, appear to be missing from this extract.
    // As the error comments below show, the added parent is not visible outside the expansion:
    // the overrides and the subtype checks still fail.
    @extendsWithPayload @experimental
    final case class Foo():
      override def description = "foo" // method description overrides nothing
    @extendsWithPayload @experimental
    final case class Bar():
      override def description = "bar" // method description overrides nothing
    summon[Foo <:< WithPayload] // Cannot prove that Foo <:< WithPayload
    val foo = new Foo()
    foo: WithPayload // Found: (foo: Foo), Required: WithPayload
    // Expansions (presumably printed from `res`); note the `extends Macros.WithPayload`:
    //List(@scala.annotation.experimental @Macros.extendsWithPayload final case class Foo() extends Macros.WithPayload {
    //  override def hashCode(): scala.Int = scala.runtime.ScalaRunTime._hashCode(Foo.this)
    //  override def equals(x$0: scala.Any): scala.Boolean = Foo.this.eq(x$0.$asInstanceOf$[java.lang.Object]).||(x$0 match {
    //    case x$0: App.Foo @scala.unchecked =>
    //      true
    //    case _ =>
    //      false
    //  })
    //  override def toString(): java.lang.String = scala.runtime.ScalaRunTime._toString(Foo.this)
    //  override def canEqual(that: scala.Any): scala.Boolean = that.isInstanceOf[App.Foo @scala.unchecked]
    //  override def productArity: scala.Int = 0
    //  override def productPrefix: scala.Predef.String = "Foo"
    //  override def productElement(n: scala.Int): scala.Any = n match {
    //    case _ =>
    //      throw new java.lang.IndexOutOfBoundsException(n.toString())
    //  }
    //  override def description: java.lang.String = "foo"
    //List(@scala.annotation.experimental @Macros.extendsWithPayload final case class Bar() extends Macros.WithPayload {
    //  override def hashCode(): scala.Int = scala.runtime.ScalaRunTime._hashCode(Bar.this)
    //  override def equals(x$0: scala.Any): scala.Boolean = Bar.this.eq(x$0.$asInstanceOf$[java.lang.Object]).||(x$0 match {
    //    case x$0: App.Bar @scala.unchecked =>
    //      true
    //    case _ =>
    //      false
    //  })
    //  override def toString(): java.lang.String = scala.runtime.ScalaRunTime._toString(Bar.this)
    //  override def canEqual(that: scala.Any): scala.Boolean = that.isInstanceOf[App.Bar @scala.unchecked]
    //  override def productArity: scala.Int = 0
    //  override def productPrefix: scala.Predef.String = "Bar"
    //  override def productElement(n: scala.Int): scala.Any = n match {
    //    case _ =>
    //      throw new java.lang.IndexOutOfBoundsException(n.toString())
    //  }
    //  override def description: java.lang.String = "bar"
    • For example, you can use code generation with Scalameta:

    // Meta-build dependency (presumably in project/build.sbt or project/plugins.sbt), so the
    // Generator under project/ can use Scalameta. NOTE(review): the closing parenthesis of
    // Seq( appears to be missing from this extract.
    libraryDependencies ++= Seq(
      "org.scalameta" %% "scalameta" % "4.7.7"


    // build.sbt. The transform task reads the Scala sources of `before` and writes the
    // rewritten files into the managed sources of `after`.
    ThisBuild / scalaVersion := "3.2.2"
    lazy val common = project
    lazy val before = project
    lazy val after = project
    // NOTE(review): the .settings(...)/.dependsOn(...) lines of these projects appear to be
    // missing from this extract. The next line (presumably part of `after`'s settings) adds its
    // sourceManaged directory to the unmanaged source directories, so the generated files get compiled.
        Compile / unmanagedSourceDirectories += (Compile / sourceManaged).value
    lazy val transform = taskKey[Unit]("Transform sources")
    // NOTE(review): the closing brace of this task appears to be missing from this extract.
    transform := {
      val inputDir  = (before / Compile / scalaSource).value
      val outputDir = (after / Compile / sourceManaged).value
      // filesToTransform is left at its default (ALL), so every file is rewritten.
      Generator.gen(inputDir, outputDir)


    import sbt.*
    /** Copies every `*.scala` file under `inputDir` to the same relative path under
      * `outputDir`. Selected files are passed through [[AnnotationProcessor]]; the others
      * are copied unchanged.
      */
    object Generator {
      // An empty selection is the sentinel for "transform every file" (see isAll).
      val ALL: Seq[String] = Seq()
      def isAll(filesToTransform: Seq[String]): Boolean = filesToTransform.isEmpty
      /** @param inputDir         root of the sources to read (all descendants matching *.scala)
        * @param outputDir        root that the resulting files are written under
        * @param filesToTransform names compared against inputFileName; empty (the default) means all files
        */
      def gen(
               inputDir: File,
               outputDir: File,
               filesToTransform: Seq[String] = ALL,
             ): Unit = {
        val finder: PathFinder = inputDir ** "*.scala"
        val scalametaTransformer = new AnnotationProcessor()
        // NOTE(review): `yield` builds a collection that is discarded (gen returns Unit);
        // a plain `for`/`foreach` would state the intent more directly.
        for (inputFile <- finder.get) yield {
          // NOTE(review): the right-hand sides of inputFileName and inputStr (presumably the
          // file name and the file contents) appear to be missing from this extract.
          val inputFileName =
          val inputStr =
          // Files that are not selected are copied through unchanged (identity).
          val transform: String => String =
            if (isAll(filesToTransform) || filesToTransform.contains(inputFileName))
              (scalametaTransformer(_: String))
            else identity
          val outputStr = transform(inputStr)
          // Mirror the input layout: same path relative to outputDir as to inputDir.
          val outputFile = outputDir / inputFile.relativeTo(inputDir).get.toString
          IO.write(outputFile, outputStr)


    import scala.meta.*
    /** Scalameta rewrite for classes annotated with `@extendsWithPayload`. It removes the
      * annotation, appends `WithPayload` to the parents and adds
      * `override val payload1: Int` / `override val payload2: Long` to the first parameter
      * list, creating one if the class has none. Trees that do not match fall through
      * `case _ => tree`.
      */
    class AnnotationProcessor extends TreeTransformer {
      // Structural match on the annotation exactly as written; presumably a qualified or
      // parenthesized form (`@pkg.extendsWithPayload`, `@extendsWithPayload()`) would not match — verify.
      val isExtendsWithPayload: Mod => Boolean = { case mod"@extendsWithPayload" => true; case _ => false }
      override def apply(tree: Tree): Tree = {
        val tree1 = tree match {
          case q"..$mods class $tname[..$tparams] ..$ctorMods (...$paramss) $template" if mods.exists(isExtendsWithPayload) =>
            // Drop the marker so it does not appear in the generated source.
            val mods1 = mods.filterNot(isExtendsWithPayload)
            template match {
              case template"{ ..$earlyStats } with ..$inits { $self => ..$stats }" =>
                val inits1 = inits :+ init"WithPayload"
                val template1 = template"{ ..$earlyStats } with ..$inits1 { $self => ..$stats }"
                val newParams = List(param"override val payload1: Int", param"override val payload2: Long")
                // Append to the first parameter list; a class without one gets a new list.
                val paramss1: List[Term.ParamClause] = paramss match {
                  case Nil => List(newParams)
                  case params :: paramss1 => (params ++ newParams) :: paramss1
                // NOTE(review): closing braces appear to be missing from this extract here
                // and after `case _ => tree`.
                q"..$mods1 class $tname[..$tparams] ..$ctorMods (...$paramss1) $template1"
          case _ => tree
        // NOTE(review): the end of apply is not visible. Nested classes are rewritten in the
        // generated output, so it presumably recurses into child trees (e.g. super.apply(tree1)) — confirm.


    /** A text-to-text source rewrite; the generator uses instances as `String => String`. */
    trait StringTransformer {
      /** Returns the rewritten form of the given source text. */
      def apply(str: String): String
    }


    import scala.meta.*
    /** Adapts a Scalameta `Transformer` (tree to tree) to a [[StringTransformer]]: the text is
      * parsed as a Scala 3 source and the tree transformation is applied to it.
      */
    trait TreeTransformer extends Transformer with StringTransformer {
      override def apply(str: String): String = {
        // NOTE(review): `.get` on the parse result presumably throws on a syntax error rather
        // than returning the input unchanged — confirm that is acceptable.
        val origTree = dialects.Scala3(str).parse[Source].get
        val newTree  = apply(origTree)
        // NOTE(review): the step that turns newTree back into text, and the closing braces,
        // appear to be missing from this extract.


    import scala.annotation.StaticAnnotation
    /** Marker annotation for the Scalameta code generator.
      *
      * It has no behavior of its own. The generator's AnnotationProcessor looks for
      * `@extendsWithPayload` on a class, removes it, and adds the `WithPayload` parent
      * together with the payload constructor fields.
      */
    class extendsWithPayload() extends StaticAnnotation


    /** Shared interface; the generator supplies payload1/payload2 for annotated classes. */
    sealed trait WithPayload:
      def description: String
      def payload1: Int
      def payload2: Long
    // Input for the Scalameta generator. Each class marked @extendsWithPayload gets the payload
    // constructor fields and the WithPayload parent; unmarked classes (Baz) are copied unchanged.
    // Without the marker the overrides below would override nothing.
    @extendsWithPayload
    final case class Foo():
      override def description = "foo"
      // Nested classes are rewritten as well.
      @extendsWithPayload
      class Nested():
        override def description = "nested"
    @extendsWithPayload
    final case class Bar():
      override def description = "bar"
    final case class Baz()

    Execute `sbt after/clean transform`; the generated sources are:


    // Generated into `after`'s managed sources by the transform task: Foo, Nested and Bar gained
    // the payload constructor fields and the WithPayload parent; Baz is unchanged.
    // NOTE(review): the closing braces of WithPayload and Foo appear to be missing from this extract.
    sealed trait WithPayload {
      def description: String
      def payload1: Int
      def payload2: Long
    final case class Foo(override val payload1: Int, override val payload2: Long) extends WithPayload {
      override def description = "foo"
      class Nested(override val payload1: Int, override val payload2: Long) extends WithPayload { override def description = "nested" }
    final case class Bar(override val payload1: Int, override val payload2: Long) extends WithPayload { override def description = "bar" }
    final case class Baz()
    • Or you can use the C preprocessor (`#define`), as you suggested:

    # Treat App.scala as C (-xc), stop after preprocessing (-E), omit line markers (-P).
    gcc -xc App.scala -E -P -o App1.scala


    // App.scala: input for the C preprocessor, not valid Scala by itself.
    // NOTE(review): the whole file goes through cpp, so text cpp treats specially (lines
    // beginning with a hash sign, unbalanced single-quote characters) may be altered — confirm.
    #define EXTENDS_WITH_PAYLOAD ( \
        override val payload1: Int, \
        override val payload2: Long \
    ) extends WithPayload
    sealed trait WithPayload:
        def description: String
        def payload1: Int
        def payload2: Long
    // Each use expands in place to the shared parameter list plus the WithPayload parent.
    final case class Foo EXTENDS_WITH_PAYLOAD:
        override def description = "foo"
    final case class Bar EXTENDS_WITH_PAYLOAD:
        override def description = "bar"


    // App1.scala: output of the gcc command above, with EXTENDS_WITH_PAYLOAD expanded on one line.
    sealed trait WithPayload:
        def description: String
        def payload1: Int
        def payload2: Long
    final case class Foo ( override val payload1: Int, override val payload2: Long ) extends WithPayload:
        override def description = "foo"
    final case class Bar ( override val payload1: Int, override val payload2: Long ) extends WithPayload:
        override def description = "bar"
    • Or you can define a Scalafix rewrite rule.

