I'm working with Scala 2.12.17.
Let's say I have a bunch of case classes:
/** Example target type with a single `String` field. */
case class TestOne(
  one: String
)

/** Example target type with two `String` fields. */
case class TestTwo(
  one: String,
  two: String
)

/** Example target type with three `String` fields. */
case class TestThree(
  one: String,
  two: String,
  three: String
)
I also have these types:
/** A value exposing an integer field `a`. */
trait Data {
  val a: Int
}

/** A [[Data]] carrying two integers; only `a` belongs to the [[Data]] contract. */
case class DoubleInt(
  a: Int,
  b: Int
) extends Data

/** A [[Data]] carrying a single integer. */
case class SingleInt(
  a: Int
) extends Data
And this function, which converts `Data` objects to `String`:
/** Renders a [[Data]] value as text: the decimal representation of its `a` field. */
def loadData(input: Data): String = String.valueOf(input.a)
What I'm looking for is a way to pass `Data` object(s) to my case classes' `apply` methods. Each `apply` method would then use the `loadData` function to convert every passed `Data` object into a `String` and build an instance of the case class. E.g.:
// Sample inputs, statically typed as the `Data` trait.
val dataOne: Data = SingleInt(1)
val dataTwo: Data = DoubleInt(1, 2)
// Desired call syntax: `Data` arguments passed straight to each case class's `apply`.
val testOne = TestOne(dataOne)
val testTwo = TestTwo(dataOne, dataTwo)
// Three `Data` arguments target the three-field `TestThree`; `TestOne` has only one field.
val testThree = TestThree(dataOne, dataTwo, dataOne)
Basically, the `TestOne` `apply` method would be:
/** Builds a [[TestOne]] from a [[Data]] value, converting it with `loadData`. */
def apply(one: Data): TestOne =
  new TestOne(loadData(one))
The `TestTwo` `apply` method would be:
/** Builds a [[TestTwo]] from two [[Data]] values, converting each with `loadData`, left to right. */
def apply(one: Data, two: Data): TestTwo = {
  val first = loadData(one)
  val second = loadData(two)
  new TestTwo(first, second)
}
Is there any way to programmatically generate those `apply` methods at compile time?
I thought that macros or macro-paradise annotations would be useful for this use case, but I'm too inexperienced with these topics to even know where to start :/
Should `val testThree = TestOne(dataOne, dataTwo, dataOne)` be `val testThree = TestThree(dataOne, dataTwo, dataOne)`?
So you'd like to replace
// Status quo: every `Data` argument must be converted with `loadData` by hand
// before calling the compiler-generated `apply(String, ...)`.
val testOne = TestOne(one = loadData(dataOne))
val testTwo = TestTwo(one = loadData(dataOne), two = loadData(dataTwo))
val testThree = TestThree(
  one = loadData(dataOne),
  two = loadData(dataTwo),
  three = loadData(dataOne)
)
with just
// Desired syntax: `Data` values passed directly to the companions' `apply`.
// This only compiles once `apply(Data, ...)` overloads exist on the companions,
// which is exactly what needs to be generated.
val testOne = TestOne(dataOne)
val testTwo = TestTwo(dataOne, dataTwo)
val testThree = TestThree(dataOne, dataTwo, dataOne)
Mapping over a tuple or case class is a standard task for libraries such as Shapeless:
// libraryDependencies += "com.chuusai" %% "shapeless" % "2.3.10"
import shapeless.poly.->
import shapeless.syntax.std.tuple._
// Shapeless polymorphic function with a single case, Data => String, delegating to `loadData`.
object loadDataPoly extends (Data -> String)(data => loadData(data))

// One field: there is no tuple to map over, so convert directly.
val testOne = TestOne(loadData(dataOne))
// Several fields: map the tuple of Data to a tuple of String, then feed it to the
// companion's `tupled` function.
val testTwo = TestTwo.tupled(
  (dataOne, dataTwo).map(loadDataPoly)
)
val testThree = TestThree.tupled(
  (dataOne, dataTwo, dataOne).map(loadDataPoly)
)
If you want to hide `loadData` entirely, you can define a generic method:
import shapeless.{Generic, HList}
import shapeless.ops.traversable.FromTraversable
/** Builds a case class `A` from `Data` arguments, converting each with `loadData`.
  *
  * Usage: `make[TestTwo](dataOne, dataTwo)`. The number of arguments is checked only at
  * runtime, because `data` is a varargs `Seq`.
  */
def make[A <: Product] = new PartiallyAppliedMake[A]

/** Second step of [[make]]: `A` is fixed explicitly, while the `HList` representation `L`
  * of `A`'s fields is inferred by implicit search.
  */
class PartiallyAppliedMake[A <: Product] {

  /** @throws IllegalArgumentException if the converted arguments cannot form `A`'s fields
    *   (e.g. the number of `data` arguments differs from the number of fields of `A`).
    */
  def apply[L <: HList](data: Data*)(implicit
    generic: Generic.Aux[A, L],
    fromTraversable: FromTraversable[L]
  ): A =
    // FromTraversable yields None on a size/type mismatch; report it with a clear message
    // instead of the opaque `None.get` NoSuchElementException.
    fromTraversable(data.map(loadData)) match {
      case Some(fields) => generic.from(fields)
      case None =>
        throw new IllegalArgumentException(
          s"make: cannot build the target case class from ${data.length} argument(s); " +
            "the number of arguments must equal the number of fields"
        )
    }
}

val testOne = make[TestOne](dataOne)
val testTwo = make[TestTwo](dataOne, dataTwo)
val testThree = make[TestThree](dataOne, dataTwo, dataOne)
val testThree_ = make[TestThree](dataOne, dataTwo, dataOne, dataOne) // fails at runtime (IllegalArgumentException)
If you want `make` to fail at compile time when the number of arguments is incorrect, the definition is a little more complicated:
import shapeless.{Generic, HList, Nat}
import shapeless.ops.hlist.{Mapper, Length, Fill}
import shapeless.ops.function.FnFromProduct
/** Compile-time-checked variant: `make[T]()` returns an ordinary function whose arity equals
  * the number of fields of `T`, so passing the wrong number of `Data` arguments is a type error.
  */
def make[TestClass <: Product] = new PartillyAppliedMake[TestClass]
// NOTE(review): "Partilly" is a typo for "Partially"; left unchanged because the class name is public.
class PartillyAppliedMake[TestClass <: Product] {
  // The implicits below are searched in order; type parameters inferred by earlier ones
  // (StringHList, N, DataHList) constrain the later searches, so their order matters.
  def apply[StringHList <: HList, N <: Nat, DataHList <: HList]()(implicit
    // TestClass <-> StringHList, its fields as an HList (e.g. String :: String :: HNil).
    generic: Generic.Aux[TestClass, StringHList],
    // N = number of fields. Unused in the body; present only to compute N.
    length: Length.Aux[StringHList, N],
    // DataHList = N copies of Data. Unused in the body; present only to compute DataHList.
    fill: Fill.Aux[N, Data, DataHList],
    // Mapping loadDataPoly over DataHList must produce exactly StringHList.
    mapper: Mapper.Aux[loadDataPoly.type, DataHList, StringHList],
    // Converts the single-HList-argument function into an N-ary function (Data, ..., Data) => TestClass.
    fnFromProduct: FnFromProduct[DataHList => TestClass]
  ): fnFromProduct.Out =
    fnFromProduct((l: DataHList) => generic.from(mapper(l)))
}
// The empty `()` ends the type-driven step; `.apply(...)` then calls the N-ary function.
val testOne = make[TestOne]().apply(dataOne)
val testTwo = make[TestTwo]().apply(dataOne, dataTwo)
val testThree = make[TestThree]().apply(dataOne, dataTwo, dataOne)
// Illustrative only: this line is expected NOT to compile and must be removed in real code.
val testThree_ = make[TestThree]().apply(dataOne, dataTwo, dataOne, dataOne) // fails at compile time
Alternatively, you can define a def macro:
// libraryDependencies += scalaOrganization.value % "scala-reflect" % scalaVersion.value
import scala.language.experimental.macros
import scala.reflect.macros.blackbox
/** Expands `make[A](d1, ..., dn)` at compile time into `new A(loadData(d1), ..., loadData(dn))`,
  * so a wrong number of arguments is reported exactly like a direct constructor call.
  * `loadData` is spliced unqualified, so it is resolved at the call site and must be in scope there.
  */
def make[A](data: Data*): A = macro makeImpl[A]

def makeImpl[A: c.WeakTypeTag](c: blackbox.Context)(data: c.Tree*): c.Tree = {
  import c.universe._
  // Wrap each argument tree in a `loadData(...)` call, preserving argument order.
  val convertedArgs: Seq[Tree] = data.map { arg => q"loadData($arg)" }
  q"new ${weakTypeOf[A]}(..$convertedArgs)"
}
// in a different subproject
// Each call expands to `new T(loadData(arg1), ..., loadData(argN))` during compilation.
val testOne = make[TestOne](dataOne) // TestOne(1)
val testTwo = make[TestTwo](dataOne, dataTwo) // TestTwo(1,1)
val testThree = make[TestThree](dataOne, dataTwo, dataOne) // TestThree(1,1,1)
// Illustrative only: 4 arguments for 3 fields is rejected by the compiler; remove in real code.
val testThree_ = make[TestThree](dataOne, dataTwo, dataOne, dataOne) // doesn't compile: too many arguments (found 4, expected 3) ...
// scalacOptions += "-Ymacro-debug-lite"
//scalac: new App.TestOne(loadData(App.this.dataOne))
//scalac: new App.TestTwo(loadData(App.this.dataOne), loadData(App.this.dataTwo))
//scalac: new App.TestThree(loadData(App.this.dataOne), loadData(App.this.dataTwo), loadData(App.this.dataOne))
But if you really want to generate `apply` methods in companion objects, you can define a macro annotation (for the build settings, see: Auto-Generate Companion Object for Case Class in Scala):
// addCompilerPlugin("org.scalamacros" % "paradise" % "2.1.1" cross CrossVersion.full) // Scala 2.12
import scala.annotation.{StaticAnnotation, compileTimeOnly}
import scala.language.experimental.macros
import scala.reflect.macros.blackbox
/** Macro annotation for a class: adds an `apply` overload to its companion object (creating the
  * companion if absent) that takes `Data` for every constructor parameter and converts each
  * with `loadData`. The expansion logic lives in GenerateApplyMacro.
  * Needs the macro paradise compiler plugin on Scala 2.12.
  */
// If the annotation is not expanded (paradise not enabled), compilation fails with this message.
@compileTimeOnly("enable macro annotations")
class generateApply extends StaticAnnotation {
  def macroTransform(annottees: Any*): Any = macro GenerateApplyMacro.macroTransformImpl
}
object GenerateApplyMacro {

  /** Expansion of [[generateApply]].
    *
    * Re-emits the annotated class unchanged and adds to its companion object (synthesising an
    * empty one if absent) an `apply` overload with the same type and value parameter lists,
    * except that every value parameter is typed `Data`. Each argument is passed through
    * `loadData` before the constructor is called.
    *
    * `Data` and `loadData` are spliced unqualified, so they must be in scope where the
    * annotated class is defined.
    *
    * Unsupported annottees are reported via `c.abort` instead of crashing with a MatchError.
    */
  def macroTransformImpl(c: blackbox.Context)(annottees: c.Tree*): c.Tree = {
    import c.universe._

    def modify(cls: ClassDef, obj: ModuleDef): Tree = (cls, obj) match {
      case (
        q"$_ class $tpname[..$tparams] $_(...$paramss) extends { ..$_ } with ..$_ { $_ => ..$_ }",
        q"$mods object $tname extends { ..$earlydefns } with ..$parents { $self => ..$body }"
      ) =>
        // Same names and modifiers as the constructor parameters, but every one typed `Data`.
        val dataParamss = paramss.map(_.map {
          case q"$paramMods val $paramName: $_ = $_" => q"$paramMods val $paramName: Data"
        })
        // Constructor arguments: each `Data` parameter converted with `loadData`.
        val argss = paramss.map(_.map {
          case q"$_ val $paramName: $_ = $_" => q"loadData($paramName)"
        })
        // The class's own type parameters, used to spell its type (e.g. `C[T]`).
        val targs = tparams.map {
          case q"$_ type $tparamName[..$_] = $_" => tq"$tparamName"
        }
        q"""
          $cls
          $mods object $tname extends { ..$earlydefns } with ..$parents { $self =>
            def apply[..$tparams](...$dataParamss): $tpname[..$targs] = {
              new $tpname[..$targs](...$argss)
            }
            ..$body
          }
        """
      case _ =>
        c.abort(c.enclosingPosition, "@generateApply: the annotated definition must be a class")
    }

    annottees match {
      case (cls: ClassDef) :: (obj: ModuleDef) :: Nil => modify(cls, obj)
      case (cls: ClassDef) :: Nil => modify(cls, q"object ${cls.name.toTermName}")
      case _ =>
        c.abort(
          c.enclosingPosition,
          "@generateApply can only annotate a class (optionally together with its companion object)"
        )
    }
  }
}
// in a different subproject
// Each annotated case class gets a companion `apply(...: Data)` overload generated at compile time.
@generateApply
case class TestOne(one: String)

@generateApply
case class TestTwo(one: String, two: String)

@generateApply
case class TestThree(one: String, two: String, three: String)

val testOne: TestOne = TestOne(dataOne) // TestOne(1)
val testTwo: TestTwo = TestTwo(dataOne, dataTwo) // TestTwo(1,1)
val testThree: TestThree = TestThree(dataOne, dataTwo, dataOne) // TestThree(1,1,1)
//scalac: {
// case class TestOne extends scala.Product with scala.Serializable {
// <caseaccessor> <paramaccessor> val one: String = _;
// def <init>(one: String) = {
// super.<init>();
// ()
// }
// };
// object TestOne extends scala.AnyRef {
// def <init>() = {
// super.<init>();
// ()
// };
// def apply(one: Data): TestOne = new TestOne(loadData(one))
// };
// ()
//}
//scalac: {
// case class TestTwo extends scala.Product with scala.Serializable {
// <caseaccessor> <paramaccessor> val one: String = _;
// <caseaccessor> <paramaccessor> val two: String = _;
// def <init>(one: String, two: String) = {
// super.<init>();
// ()
// }
// };
// object TestTwo extends scala.AnyRef {
// def <init>() = {
// super.<init>();
// ()
// };
// def apply(one: Data, two: Data): TestTwo = new TestTwo(loadData(one), loadData(two))
// };
// ()
//}
//scalac: {
// case class TestThree extends scala.Product with scala.Serializable {
// <caseaccessor> <paramaccessor> val one: String = _;
// <caseaccessor> <paramaccessor> val two: String = _;
// <caseaccessor> <paramaccessor> val three: String = _;
// def <init>(one: String, two: String, three: String) = {
// super.<init>();
// ()
// }
// };
// object TestThree extends scala.AnyRef {
// def <init>() = {
// super.<init>();
// ()
// };
// def apply(one: Data, two: Data, three: Data): TestThree = new TestThree(loadData(one), loadData(two), loadData(three))
// };
// ()
//}