
Specify concrete type for methods in Scala trait


I want to define a method in a Scala trait where both a parameter to the method and the return type correspond to the same concrete class which extends the trait. I've tried something like the following:

trait A {
  def foo(obj: this.type): this.type
}

final case class B(val bar: Int) extends A {
  override def foo(obj: B): B = {
    B(obj.bar + this.bar)
  }
}

object Main {
  def main(args: Array[String]) = {
    val b1 = new B(0)
    val b2 = new B(0)
    val b3: B = b1.foo(b2)
  }
}

However, trying to compile this code gives the following errors:

Test.scala:5: error: class B needs to be abstract. Missing implementation for:
  def foo(obj: B.this.type): B.this.type // inherited from trait A
case class B(val bar: Int) extends A {
           ^
Test.scala:6: error: method foo overrides nothing.
Note: the super classes of class B contain the following, non final members named foo:
def foo: ((obj: _1.type): _1.type) forSome { val _1: B }
  override def foo(obj: B): B = {
               ^
2 errors

There's obviously something I'm misunderstanding about the Scala type system here. The signature of foo in class B is what I want it to be, but I don't know how to correctly define the method in A (or if this is even possible). It seems like this question is asking something quite similar, but I don't immediately see how the answer applies in my situation.


Solution

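  • The type annotation this.type is the singleton type of the receiver. Apart from null, the only value of that type is this itself, so a method declared to return this.type may only return this, not another instance of B. The same holds for the method parameter: obj would have to be the receiver itself.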

    If only the return type mattered, a solution would be to declare foo in A to return A. The overriding method in B can then narrow (covariantly specialize) the return type to B.

    However, you also want the argument to have the type of the subtype, and an overriding method cannot narrow its parameter types. For that you can use a self-recursive type (also known as F-bounded polymorphism). The following example compiles and should do what you want.

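      trait A[S <: A[S]] {
        // S stands for the concrete subtype, so foo takes and returns that subtype
        def foo(obj: S): S
      }
    
      final case class B(bar: Int) extends A[B] {
        override def foo(obj: B): B = {
          B(obj.bar + this.bar) // same computation as in the question
        }
      }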
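For comparison, the return-type-only case mentioned in the answer can be sketched as follows. The names Combinable, Sum and CovariantReturnDemo are hypothetical, chosen only for illustration, and the sketch assumes Scala 2.13 or Scala 3. It shows that an override may narrow its return type from the trait to the subclass, that the parameter type has to stay as declared in the trait, and that a method returning this.type can only hand back the receiver.

```scala
// Hypothetical names (Combinable, Sum, CovariantReturnDemo), for illustration only.
trait Combinable {
  // Return type declared as the trait itself; subclasses may narrow it.
  def combine(obj: Combinable): Combinable

  // this.type: the only (non-null) value of this type is the receiver itself.
  def touch(): this.type = this
}

final case class Sum(bar: Int) extends Combinable {
  // Covariant return type: narrowing Combinable to Sum is a legal override.
  // The parameter must remain Combinable; declaring it as Sum would not override.
  override def combine(obj: Combinable): Sum = obj match {
    case Sum(other) => Sum(bar + other)
    case _          => this // non-Sum arguments are simply ignored in this sketch
  }
}

object CovariantReturnDemo {
  def main(args: Array[String]): Unit = {
    val s: Sum = Sum(1).combine(Sum(2)) // static type is Sum thanks to the narrowed return type
    val t: Sum = s.touch()              // returns s itself, typed as s.type, which conforms to Sum
    println(s"$s ${t eq s}")
  }
}
```

Running CovariantReturnDemo.main prints Sum(3) true. The pattern match inside combine is the price of keeping the parameter typed as Combinable, which is why the answer switches to a self-recursive type once the argument also has to be the subtype.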
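A complete, runnable sketch of the self-recursive approach, including a main method like the one in the question, might look like this. It assumes Scala 2.13 or Scala 3. The self-type annotation this: S => and the helper combineAll are additions that do not appear in the answer: the first tightens the bound, and the second shows that generic code written against A[S] keeps the concrete type.

```scala
// Sketch assuming Scala 2.13 or Scala 3.
// The self-type `this: S =>` is an addition to the answer's trait: it requires every
// implementation to be an S, so e.g. `class C extends A[B]` is rejected at compile time.
trait A[S <: A[S]] { this: S =>
  def foo(obj: S): S
}

final case class B(bar: Int) extends A[B] {
  // The signature the question asked for: both parameter and result are B.
  override def foo(obj: B): B = B(obj.bar + this.bar)
}

object Main {
  // Hypothetical helper: folds foo over the arguments and still returns the concrete S.
  def combineAll[S <: A[S]](first: S, rest: S*): S =
    rest.foldLeft(first)((acc, next) => acc.foo(next))

  def main(args: Array[String]): Unit = {
    val b1 = B(1)
    val b2 = B(2)
    val b3: B = b1.foo(b2)                // B(3)
    val total: B = combineAll(b1, b2, b3) // B(1) foo B(2) = B(3), then B(3) foo B(3) = B(6)
    println(s"$b3 $total")
  }
}
```

With the self-type in place, a declaration such as class C(bar: Int) extends A[B] (C being a hypothetical class) fails to compile, whereas the answer's trait alone would accept it as long as C implements foo(obj: B): B. The F-bound by itself only says that S is some subtype of A[S], not that S is the implementing class.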