Tags: scala, functional-programming, polymorphism, typeclass, implicit

How to change a base class field in a functional fashion in Scala


Say I have this hierarchy:

trait Base {
  val tag: String
}

case class Derived1(tag: String = "Derived 1") extends Base
case class Derived2(tag: String = "Derived 2") extends Base
//etc ...

and I want to define a method with the following signature:

def tag[T <: Base](instance: T, tag: String): T

that returns an instance of type T with the modified tag: String. So when, e.g., a Derived1 instance is passed in, a modified instance of the same type is returned.

This goal could easily be accomplished by making tag a mutable variable (var tag: String). How can I achieve the desired behaviour in Scala using functional programming?
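For a concrete case class, the immutable answer is simply copy: it returns a new instance and leaves the original untouched. The difficulty is only that copy is generated per case class and is not a member of Base, so it cannot be called through a T <: Base. A minimal sketch, assuming Scala 2.13; CopyDemo and retagDerived1 are hypothetical names used only for illustration:

```scala
trait Base {
  val tag: String
}

case class Derived1(tag: String = "Derived 1") extends Base

object CopyDemo {
  // Compiles because the static type is Derived1, whose generated copy method is visible
  def retagDerived1(d: Derived1, newTag: String): Derived1 = d.copy(tag = newTag)

  // The generic version does not compile: Base declares no copy method
  // def tag[T <: Base](instance: T, tag: String): T = instance.copy(tag = tag)
}
```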

My thought:

I could create a type class and its instances

trait Tagger[T] {
  def tag(t: T, state: String): T
}

implicit object TaggerDerived1 extends Tagger[Derived1] {
  override def tag(t: Derived1, state: String): Derived1 = ???
}

implicit object TaggerDerived2 extends Tagger[Derived2] {
  override def tag(t: Derived2, state: String): Derived2 = ???
}

implicit object TaggerBase extends Tagger[Base] {
  override def tag(t: Base, state: String): Base = ???
}

and a method

def tag[T <: Base](instance: T, tag: String)(implicit tagger: Tagger[T]): T = tagger.tag(instance, tag)

This is not ideal. First of all, users must know to provide an instance whenever they define their own derived classes. If they don't, implicit resolution would fall back to the Base implementation and narrow the return type.

case class Derived3(tag: String = "Derived 3") extends Base


tag(Derived3(), "test") // falls back to `tag[Base](...)`
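For reference, the ??? bodies of the per-class instances can be filled with each case class's own copy, which is exactly the boilerplate every new subclass would have to repeat. A minimal sketch, assuming Scala 2.13; placing the instances in the Tagger companion object (so they are found without an import) and the wrapper name TaggerDemo are choices made for this example, not part of the question:

```scala
trait Base {
  val tag: String
}

case class Derived1(tag: String = "Derived 1") extends Base
case class Derived2(tag: String = "Derived 2") extends Base

trait Tagger[T] {
  def tag(t: T, state: String): T
}

object Tagger {
  // Each instance delegates to the case class's generated copy method
  implicit object TaggerDerived1 extends Tagger[Derived1] {
    override def tag(t: Derived1, state: String): Derived1 = t.copy(tag = state)
  }

  implicit object TaggerDerived2 extends Tagger[Derived2] {
    override def tag(t: Derived2, state: String): Derived2 = t.copy(tag = state)
  }
}

object TaggerDemo {
  // Same signature as the method in the question
  def tag[T <: Base](instance: T, tag: String)(implicit tagger: Tagger[T]): T =
    tagger.tag(instance, tag)
}
```

Since this sketch has no TaggerBase instance, calling TaggerDemo.tag on a class without its own instance is rejected at compile time instead of being widened to Base.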

Now I am leaning towards using mutable state with var tag: String. However, I would love to hear opinions on how to solve this purely functionally in Scala.


Solution

  • You can derive your type class Tagger (here with shapeless), so that users will not have to define an instance for every new case class extending Base

    // libraryDependencies += "com.chuusai" %% "shapeless" % "2.3.10"
    import shapeless.labelled.{FieldType, field}
    import shapeless.{::, HList, HNil, LabelledGeneric, Witness}
    
    trait Tagger[T] {
      def tag(t: T, state: String): T
    }
    
    trait LowPriorityTagger {
      implicit def notTagFieldTagger[K <: Symbol : Witness.Aux, V, T <: HList](implicit
        tagger: Tagger[T]
      ): Tagger[FieldType[K, V] :: T] =
        (t, state) => t.head :: tagger.tag(t.tail, state)
    }
    
    object Tagger extends LowPriorityTagger {
      implicit def genericTagger[T <: Base with Product, L <: HList](implicit
        generic: LabelledGeneric.Aux[T, L],
        tagger: Tagger[L]
      ): Tagger[T] = (t, state) => generic.from(tagger.tag(generic.to(t), state))
    
      implicit val hnilTagger: Tagger[HNil] = (_, _) => HNil
    
      implicit def tagFieldTagger[T <: HList]:
        Tagger[FieldType[Witness.`'tag`.T, String] :: T] = 
        (t, state) => field[Witness.`'tag`.T](state) :: t.tail
    }
    
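    // Base and the tag method as in the question
    trait Base {
      val tag: String
    }
    
    def tag[T <: Base](instance: T, tag: String)(implicit tagger: Tagger[T]): T =
      tagger.tag(instance, tag)
    
    case class Derived1(tag: String = "Derived 1") extends Base
    case class Derived2(tag: String = "Derived 2") extends Base
    case class Derived3(i: Int, tag: String = "Derived 3", s: String) extends Base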
    
    tag(Derived1("aaa"), "bbb") // Derived1(bbb)
    tag(Derived2("ccc"), "ddd") // Derived2(ddd)
    tag(Derived3(1, "ccc", "xxx"), "ddd") // Derived3(1,ddd,xxx)
    

    Alternatively, for single-parameter case classes, you can constrain T with a structural type so that it has .copy

    import scala.language.reflectiveCalls
    def tag[T <: Base {def copy(tag: String): T}](instance: T, tag: String): T =
      instance.copy(tag = tag)
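To see the structural bound in context, here is a self-contained version with the hierarchy from the question. It is a sketch assuming Scala 2.13 (Scala 3 handles structural calls differently); the wrapper object StructuralTag is a hypothetical name, needed only because Scala 2 does not allow top-level defs. The result keeps the static type of the argument:

```scala
import scala.language.reflectiveCalls

trait Base {
  val tag: String
}

case class Derived1(tag: String = "Derived 1") extends Base
case class Derived2(tag: String = "Derived 2") extends Base

object StructuralTag {
  // T must be a Base whose copy takes a single String parameter named tag and returns T;
  // the call goes through runtime reflection, hence the reflectiveCalls import
  def tag[T <: Base { def copy(tag: String): T }](instance: T, tag: String): T =
    instance.copy(tag = tag)
}
```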
    

    For multi-parameter case classes it's harder to express the existence of .copy in types, because the signature of copy differs from class to class (it lists all the fields) and would have to be computed.

    So you can make tag a macro

    // libraryDependencies += scalaOrganization.value % "scala-reflect" % scalaVersion.value
    import scala.language.experimental.macros
    import scala.reflect.macros.blackbox
    
    def tag[T <: Base](instance: T, tag: String): T = macro tagImpl
    
    def tagImpl(c: blackbox.Context)(instance: c.Tree, tag: c.Tree): c.Tree = {
      import c.universe._
      q"$instance.copy(tag = $tag)"
    }
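Scala 2 def macros must be compiled before the code that calls them, and the implementation has to be reachable statically (e.g. a method of a top-level object). In practice that usually means a separate sbt subproject. A build.sbt sketch, assuming hypothetical subproject names macros and core, with Base, tag and tagImpl living in an object inside macros:

```scala
// build.sbt (sketch; subproject names are hypothetical)
lazy val macros = (project in file("macros"))
  .settings(
    // scala-reflect provides scala.reflect.macros.blackbox
    libraryDependencies += scalaOrganization.value % "scala-reflect" % scalaVersion.value
  )

lazy val core = (project in file("core"))
  .dependsOn(macros) // code calling tag(...) lives here, compiled after macros
```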
    

    Or you can use runtime reflection (Java or Scala reflection, with or without the Product API)

    import scala.reflect.{ClassTag, classTag}
    import scala.reflect.runtime.{currentMirror => rm}
    import scala.reflect.runtime.universe.{TermName, termNames}
    
    def tag[T <: Base with Product : ClassTag](instance: T, tag: String): T = {
      // Product (productElementNames requires Scala 2.13+)
      val values = instance.productElementNames.zip(instance.productIterator)
        .map { case (fieldName, fieldValue) => if (fieldName == "tag") tag else fieldValue }.toSeq
    
      // Java reflection (Java varargs expect Object, so the values are cast to AnyRef)
      // val clazz = instance.getClass
      // clazz.getMethods.find(_.getName == "copy").get.invoke(instance, values.map(_.asInstanceOf[AnyRef]): _*).asInstanceOf[T]
      // clazz.getConstructors.head.newInstance(values.map(_.asInstanceOf[AnyRef]): _*).asInstanceOf[T]
    
      // Scala reflection
      val clazz = classTag[T].runtimeClass
      val classSymbol = rm.classSymbol(clazz)
      // val copyMethodSymbol = classSymbol.typeSignature.decl(TermName("copy")).asMethod
      // rm.reflect(instance).reflectMethod(copyMethodSymbol)(values: _*).asInstanceOf[T]
      val constructorSymbol = classSymbol.typeSignature.decl(termNames.CONSTRUCTOR).asMethod
      rm.reflectClass(classSymbol).reflectConstructor(constructorSymbol)(values: _*).asInstanceOf[T]
    }
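If you want to avoid the scala-reflect dependency, the Product path can be combined with plain Java reflection on the constructor. This is a sketch assuming Scala 2.13 (for productElementNames) and a case class with a single public constructor declared at top level or inside an object; the wrapper name ReflectiveTag is hypothetical:

```scala
trait Base {
  val tag: String
}

case class Derived3(i: Int, tag: String = "Derived 3", s: String) extends Base

object ReflectiveTag {
  def tag[T <: Base with Product](instance: T, tag: String): T = {
    // Field values in declaration order, with the field named "tag" replaced
    val values: Seq[AnyRef] = instance.productElementNames
      .zip(instance.productIterator)
      .map { case (fieldName, fieldValue) =>
        if (fieldName == "tag") tag else fieldValue.asInstanceOf[AnyRef]
      }
      .toSeq
    // The primary constructor takes the fields in the same order;
    // Constructor.newInstance unboxes arguments for primitive parameters such as Int
    instance.getClass.getConstructors.head
      .newInstance(values: _*)
      .asInstanceOf[T]
  }
}
```

Unlike the type class and macro versions, a class without a tag field or with extra constructors only fails at runtime here.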