Search code examples
Tags: scala, traits, scala-macros, scala-macro-paradise

Make a Scala class extend a trait/abstract class with macros


The problem:

I want to make the annotated class a subclass of another class using a Scala macro. What I have:

Wrapper for fields:

/** Wrapper for a single field: pairs the field's type descriptor with its name.
  *
  * Declared as a case class so that `Field(fieldType, fieldName)` can be written
  * without `new`, which is how the desired macro expansion constructs it.
  *
  * @param fieldType descriptor of the field's type
  * @param fieldName name of the field
  */
case class Field(fieldType: DbModelFieldType, fieldName: String)

An abstract class (base class for all annotated classes):

/** Base class for all annotated model classes. */
abstract class DatabaseModel {
  /** The field wrappers describing this model. */
  def fields: Seq[Field]
}

I have a case class:

case class Model(num: Int, sym: Char, descr: String)

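If I annotate that class with @GetFromDB, the annotated declaration (first two lines below) should be expanded into the full class definition that follows it: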

@GetFromDB
case class Model(num: Int, sym: Char, descr: String)
/** Hand-written form of the class the annotation is expected to produce:
  * the model extends `DatabaseModel` and lists one `Field` per constructor parameter.
  */
case class Model(num: Int, sym: Char, descr: String) extends DatabaseModel {
  /** One `Field` per constructor parameter, in declaration order. */
  override def fields: Seq[Field] =
    Seq(
      Field(IntFieldType(), "num"),
      Field(CharFieldType(), "sym"),
      Field(StringFieldType(), "descr")
    )
}

my desired result should be something like this:

// The type ascription only compiles if Model is a subtype of DatabaseModel.
val m: DatabaseModel = Model(1, 'A', "First Name")

I have looked at a similar question:

Generate companion object for case class with methods (field = method)

How can I extend that solution to achieve the desired result?

import scala.annotation.{StaticAnnotation, compileTimeOnly}
import scala.language.experimental.macros
import scala.reflect.macros.whitebox

/** Macro annotation (and its implementation) that generates a companion object
  * for the annotated class.
  */
object Macros {
  /** Annotation whose expansion is delegated to `Macro.impl`.
    *
    * The `compileTimeOnly` message is what the compiler reports when the
    * annotation survives unexpanded.
    */
  @compileTimeOnly("enable macro paradise")
  class GenerateCompanionWithFields extends StaticAnnotation {
    def macroTransform(annottees: Any*): Any = macro Macro.impl
  }

  object Macro {
    /** Re-emits the annotated class unchanged and adds a companion object that
      * has one method per constructor parameter; each method is named after the
      * parameter and returns the parameter's declared type, printed as a String.
      *
      * Expansion is aborted with a compile error when the annottee is not a
      * single class definition.
      *
      * @param c         the macro context
      * @param annottees the annotated definitions
      * @return a tree containing the original class followed by the generated companion
      */
    def impl(c: whitebox.Context)(annottees: c.Tree*): c.Tree = {
      import c.universe._
      annottees match {
        case (cls @ q"$_ class $tpname[..$_] $_(...$paramss) extends { ..$_ } with ..$_ { $_ => ..$_ }") :: Nil =>

          val newMethods = paramss.flatten.map {
            case q"$_ val $tname: $tpt = $_" =>
              q"def $tname(): String = ${tpt.toString}"
            case other =>
              // Report an unexpected parameter shape as a compile error instead of a MatchError.
              c.abort(other.pos, s"@GenerateCompanionWithFields: unsupported constructor parameter: $other")
          }

          q"""
             $cls

             object ${tpname.toTermName} {
               ..$newMethods
             }
           """

        case _ =>
          // Anything other than exactly one class definition is not handled by this macro.
          c.abort(
            c.enclosingPosition,
            "@GenerateCompanionWithFields supports only a class without an explicit companion object"
          )
      }
    }
  }
}

Solution

  • Try

    import scala.annotation.{StaticAnnotation, compileTimeOnly}
    import scala.language.experimental.macros
    import scala.reflect.macros.whitebox
    
    /** Macro annotation (and its implementation) that makes the annotated class
      * extend `DatabaseModel` and implements `fields` from its constructor parameters.
      */
    object Macros {
      /** Annotation whose expansion is delegated to `GetFromDBMacro.impl`.
        *
        * The `compileTimeOnly` message is what the compiler reports when the
        * annotation survives unexpanded.
        */
      @compileTimeOnly("enable macro paradise")
      class GetFromDB extends StaticAnnotation {
        def macroTransform(annottees: Any*): Any = macro GetFromDBMacro.impl
      }

      object GetFromDBMacro {
        /** Rebuilds the annotated class with `DatabaseModel` prepended to its parents
          * and an added `fields` override; any remaining annottees are emitted unchanged.
          *
          * For every constructor parameter `name: T` the override contains
          * `Field(TFieldType.apply(), "name")`. The `TFieldType` term name is built by
          * appending "FieldType" to the printed type tree, and `Field` / `DatabaseModel`
          * are referenced unqualified, so they must be in scope where the annotation is used.
          *
          * Expansion is aborted with a compile error when the first annottee is not a class.
          *
          * @param c         the macro context
          * @param annottees the annotated definitions
          * @return the rewritten class followed by the remaining annottees
          */
        def impl(c: whitebox.Context)(annottees: c.Tree*): c.Tree = {
          import c.universe._
          annottees match {
            case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" :: tail =>
              val fields = paramss.flatten.map {
                case q"$_ val $tname: $tpt = $_" => q"Field(${TermName(tpt.toString + "FieldType")}.apply(), ${tname.toString})"
                case other =>
                  // Report an unexpected parameter shape as a compile error instead of a MatchError.
                  c.abort(other.pos, s"@GetFromDB: unsupported constructor parameter: $other")
              }
              q"""
                 $mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..${tq"DatabaseModel" :: parents} { $self =>
                   override def fields: _root_.scala.collection.immutable.Seq[Field] = _root_.scala.collection.immutable.Seq.apply(..$fields)

                   ..$stats
                 }

                 ..$tail
               """

            case _ =>
              // Traits, objects, defs etc. are not handled by this macro.
              c.abort(c.enclosingPosition, "@GetFromDB can only be applied to a class")
          }
        }
      }
    }
    
    import Macros._
    
    /** Demo: declares the field-type tags, `Field` and `DatabaseModel`, annotates
      * `Model` with `@GetFromDB`, and prints the generated `fields`.
      */
    object App {
      // Field-type tags; the macro references them as `<ParamType>FieldType.apply()`.
      sealed trait DbModelFieldType
      case class IntFieldType() extends DbModelFieldType
      case class CharFieldType() extends DbModelFieldType
      case class StringFieldType() extends DbModelFieldType

      /** A field's type tag together with its name. */
      case class Field(fieldType: DbModelFieldType, fieldName: String)

      /** Parent that the macro adds to every annotated class. */
      abstract class DatabaseModel {
        def fields: Seq[Field]
      }

      // Expansion reported by scalac for the annotated class below:
      //
      //Warning:scalac: {
      //  case class Model extends DatabaseModel with scala.Product with scala.Serializable {
      //    <caseaccessor> <paramaccessor> val num: Int = _;
      //    <caseaccessor> <paramaccessor> val sym: Char = _;
      //    <caseaccessor> <paramaccessor> val descr: String = _;
      //    def <init>(num: Int, sym: Char, descr: String) = {
      //      super.<init>();
      //      ()
      //    };
      //    override def fields: _root_.scala.collection.immutable.Seq[Field] = _root_.scala.collection.immutable.Seq.apply(Field(IntFieldType.apply(), "num"), Field(CharFieldType.apply(), "sym"), Field(StringFieldType.apply(), "descr"))
      //  };
      //  ()
      //}
      @GetFromDB
      case class Model(num: Int, sym: Char, descr: String)

      // Type-checks only because the expanded Model extends DatabaseModel.
      val m: DatabaseModel = Model(1, 'A', "First Name")

      def main(args: Array[String]): Unit = {
        // Prints: List(Field(IntFieldType(),num), Field(CharFieldType(),sym), Field(StringFieldType(),descr))
        val generatedFields = m.fields
        println(generatedFields)
      }
    }