Search code examples
Tags: scala, inheritance, type-parameter

Type parameters and inheritance in Scala


Is there a simple way to return the concrete type from an overriding method? What about creating an instance of a concrete implementation, or calling chained methods implemented in the concrete class so that they also return the correct type? I have a solution (based on https://stackoverflow.com/a/14905650), but I feel these things should be simpler.

There are many similar questions, but everyone's case is a little different, so here is another example (shortened from https://github.com/valdanylchuk/saiml/tree/master/src/main/scala/saiml/ga). When replying, if possible, please check if the whole code block compiles with your suggested change, because there are subtle cascading effects. I could not make this work with the "curiously recurring template pattern", for example (not that I find it nicer).

import scala.reflect.ClassTag
import scala.util.Random

/** Base class for anything that carries a string genome.
  *
  * The abstract type member `Self` lets each concrete subclass state which type its
  * [[crossover]] produces, so callers holding the concrete type get it back.
  */
abstract class Individual(val genome: String) {

  /** The concrete type returned by [[crossover]]; subclasses bind it to themselves. */
  type Self

  /** Combines this individual with `that`, returning the subclass's `Self` type. */
  def crossover(that: Individual): Self

  /** No-arg constructor that passes an empty genome. Constructors are not inherited, so
    * subclasses that want a random genome declare their own no-arg constructor.
    */
  def this() = {
    this("")
  }
}
/** String-genome individual. The no-arg constructor builds a random alphanumeric genome
  * whose length equals `"Hello, World!".length`.
  */
class HelloGenetic(override val genome: String) extends Individual {
  type Self = HelloGenetic

  /** Creates an individual with a random alphanumeric genome of the target length. */
  def this() = this(Random.alphanumeric.take("Hello, World!".length).mkString)

  /** Single-point crossover: the first 6 characters of this genome followed by everything
    * after index 6 of `that` genome.
    *
    * `take`/`drop` match `substring` for genomes of at least 6 characters. Unlike
    * `substring`, they never throw, so shorter genomes no longer raise
    * `StringIndexOutOfBoundsException`.
    */
  override def crossover(that: Individual): HelloGenetic = {
    val cut = 6
    val newGenome = genome.take(cut) + that.genome.drop(cut)
    new HelloGenetic(newGenome)
  }
}
/** One generation of individuals of concrete type `A`.
  *
  * The refinement `{type Self = A}` makes `crossover` called on an `A` return an `A`.
  *
  * @param size             number of individuals created, and bred by [[evolve]], per generation
  * @param tournamentSize   tournament size for selection (unused by the simplified selection below)
  * @param givenIndividuals existing individuals; when `None`, `size` new ones are built through
  *                         `A`'s no-arg constructor, located via the `ClassTag`
  */
class Population[A <: Individual {type Self = A} :ClassTag]( val size: Int,
    tournamentSize: Int, givenIndividuals: Option[Vector[A]] = None) {
  val individuals: Vector[A] = givenIndividuals.getOrElse {
    val cls = implicitly[ClassTag[A]].runtimeClass
    // Class#newInstance is deprecated since Java 9: it rethrows checked constructor
    // exceptions unwrapped. Its documented replacement goes through the no-arg constructor
    // explicitly; constructor failures now arrive wrapped in InvocationTargetException.
    Vector.fill(size)(cls.getDeclaredConstructor().newInstance().asInstanceOf[A])
  }

  /** Picks a parent for breeding; currently a placeholder returning the first individual. */
  def tournamentSelect(): A = individuals.head  // not really, skipped

  /** Breeds `size` children by crossing over selected parents, keeping the concrete type `A`. */
  def evolve: Population[A] = {
    val nextGen = Vector.fill[A](size) {
      val parent1: A = tournamentSelect()
      val parent2: A = tournamentSelect()
      parent1.crossover(parent2)
    }
    new Population(size, tournamentSize, Some(nextGen))
  }
}
/** Runs the genetic loop for individuals of concrete type `A`.
  *
  * @param populationSize number of individuals in each generation
  * @param tournamentSize forwarded to every [[Population]]
  */
class Genetic[A <: Individual {type Self = A} :ClassTag](populationSize: Int, tournamentSize: Int) {

  /** Evolves `maxGen` generations, starting from a freshly created population, and returns
    * the first individual of the final generation.
    *
    * The result is typed as the concrete `A` rather than `Individual`. Since
    * `A <: Individual`, callers that expect an `Individual` still compile.
    *
    * NOTE(review): `maxMillis` is accepted but not used; no time limit is enforced.
    * Throws if the final population is empty (e.g. `populationSize` <= 0).
    */
  def optimize(maxGen: Int, maxMillis: Long): A = {
    val first = new Population[A](populationSize, tournamentSize)
    val optPop = (0 until maxGen).foldLeft(first) { (pop, _) => pop.evolve }
    optPop.individuals.head
  }
}

Solution

  • The CRTP version is

    /** F-bounded ("CRTP") base class: `A` is the concrete subclass itself, so [[crossover]]
      * both accepts and returns that concrete type.
      */
    abstract class Individual[A <: Individual[A]](val genome: String) {

      /** Combines this individual with `that` (of the same concrete type) into an `A`. */
      def crossover(that: A): A

      /** No-arg constructor passing an empty genome. Constructors are not inherited, so
        * subclasses wanting a random genome declare their own no-arg constructor.
        */
      def this() = {
        this("")
      }
    }
    /** CRTP string-genome individual. The no-arg constructor builds a random alphanumeric
      * genome whose length equals `"Hello, World!".length`.
      */
    class HelloGenetic(override val genome: String) extends Individual[HelloGenetic] {

      /** Creates an individual with a random alphanumeric genome of the target length. */
      def this() = this(Random.alphanumeric.take("Hello, World!".length).mkString)

      /** Single-point crossover: the first 6 characters of this genome followed by everything
        * after index 6 of `that` genome.
        *
        * `take`/`drop` match `substring` for genomes of at least 6 characters. Unlike
        * `substring`, they never throw, so shorter genomes no longer raise
        * `StringIndexOutOfBoundsException`.
        */
      override def crossover(that: HelloGenetic): HelloGenetic = {
        val cut = 6
        val newGenome = genome.take(cut) + that.genome.drop(cut)
        new HelloGenetic(newGenome)
      }
    }
    /** One generation of CRTP individuals of concrete type `A`.
      *
      * @param size             number of individuals created, and bred by [[evolve]], per generation
      * @param tournamentSize   tournament size for selection (unused by the simplified selection below)
      * @param givenIndividuals existing individuals; when `None`, `size` new ones are built through
      *                         `A`'s no-arg constructor, located via the `ClassTag`
      */
    class Population[A <: Individual[A] :ClassTag]( val size: Int,
        tournamentSize: Int, givenIndividuals: Option[Vector[A]] = None) {
      val individuals: Vector[A] = givenIndividuals.getOrElse {
        val cls = implicitly[ClassTag[A]].runtimeClass
        // Class#newInstance is deprecated since Java 9: it rethrows checked constructor
        // exceptions unwrapped. Its documented replacement goes through the no-arg constructor
        // explicitly; constructor failures now arrive wrapped in InvocationTargetException.
        Vector.fill(size)(cls.getDeclaredConstructor().newInstance().asInstanceOf[A])
      }

      /** Picks a parent for breeding; currently a placeholder returning the first individual. */
      def tournamentSelect(): A = individuals.head  // not really, skipped

      /** Breeds `size` children by crossing over selected parents. */
      def evolve: Population[A] = {
        val nextGen = Vector.fill[A](size) {
          val parent1: A = tournamentSelect()
          val parent2: A = tournamentSelect()
          parent1.crossover(parent2)
        }
        new Population(size, tournamentSize, Some(nextGen))
      }
    }
    /** Runs the genetic loop for CRTP individuals of concrete type `A`.
      *
      * @param populationSize number of individuals in each generation
      * @param tournamentSize forwarded to every [[Population]]
      */
    class Genetic[A <: Individual[A] :ClassTag](populationSize: Int, tournamentSize: Int) {

      /** Evolves `maxGen` generations, starting from a freshly created population, and returns
        * the first individual of the final generation.
        *
        * The result is typed as the concrete `A` rather than `Individual[A]`. Since
        * `A <: Individual[A]`, callers that expect an `Individual[A]` still compile.
        *
        * NOTE(review): `maxMillis` is accepted but not used; no time limit is enforced.
        * Throws if the final population is empty (e.g. `populationSize` <= 0).
        */
      def optimize(maxGen: Int, maxMillis: Long): A = {
        val first = new Population[A](populationSize, tournamentSize)
        val optPop = (0 until maxGen).foldLeft(first) { (pop, _) => pop.evolve }
        optPop.individuals.head
      }
    }
    

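    which compiles. For creating the instances, I'd suggest passing a factory function instead of using reflection (`...` stands for the members unchanged from above; note that `evolve` must now also pass `makeIndividual` to the new `Population`):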

    /** CRTP population that receives a factory function instead of reflecting on a `ClassTag`.
      *
      * @param makeIndividual creates one new individual; called `size` times when
      *                       `givenIndividuals` is `None`
      */
    class Population[A <: Individual[A]](val size: Int,
        tournamentSize: Int, makeIndividual: () => A, givenIndividuals: Option[Vector[A]] = None) {
      // Vector.fill takes its element by name, so makeIndividual() is invoked once per element,
      // and only when givenIndividuals is None (getOrElse's argument is also by name).
      val individuals: Vector[A] = givenIndividuals getOrElse
        Vector.fill(size)(makeIndividual())
      // NOTE: makeIndividual has no default, so any `new Population(...)` in the elided
      // members must pass it as the third argument, before givenIndividuals.
      ...
    }
    

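    If you want to pass the factory implicitly, define a small type class and put its instance in the companion object, where implicit search finds it without any import: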

    /** Type class that produces fresh individuals of type `A`, replacing the reflective
      * no-arg constructor lookup. It has a single abstract method, so in Scala 2.12+ a
      * `() => ...` lambda can also implement it.
      */
    trait IndividualFactory[A] {
      /** Creates a new individual. */
      def apply(): A
    }
    
    class HelloGenetic ... // remove def this() 
    /** Companion of `HelloGenetic`. Because it is the companion of a type argument of
      * `IndividualFactory[HelloGenetic]`, its implicit `factory` is found without an import.
      */
    object HelloGenetic {

      /** Random alphanumeric genome as long as the target phrase. */
      private def randomGenome(): String = {
        val targetLength = "Hello, World!".length
        Random.alphanumeric.take(targetLength).mkString
      }

      implicit val factory: IndividualFactory[HelloGenetic] = new IndividualFactory[HelloGenetic] {
        def apply(): HelloGenetic = new HelloGenetic(randomGenome())
      }
    }
    /** CRTP population whose individuals come from an implicit [[IndividualFactory]]. */
    class Population[A <: Individual[A]](val size: Int,
        tournamentSize: Int, givenIndividuals: Option[Vector[A]] = None)
        (implicit factory: IndividualFactory[A]) {
      // factory() is evaluated once per element (Vector.fill's element is by name), and only
      // when givenIndividuals is None. The implicit `factory` parameter is itself in implicit
      // scope inside this body, so a `new Population(size, tournamentSize, Some(...))` call in
      // the elided members picks it up without passing it explicitly.
      val individuals: Vector[A] = givenIndividuals getOrElse
        Vector.fill(size)(factory())
      ...
    }