scala generics tree pattern-matching type-bounds

Why can't I return a concrete subtype of A if a generic subtype of A is declared as return parameter?


abstract class IntTree
object Empty extends IntTree
case class NonEmpty(elem: Int, left: IntTree, right: IntTree) extends IntTree

def assertNonNegative[S <: IntTree](t: S): S = {
  t match {
    case Empty => Empty  // type mismatch, required: S, found: Empty.type
    case NonEmpty(elem, left, right) =>
      if (elem < 0) throw new Exception
      else NonEmpty(elem, assertNonNegative(left), assertNonNegative(right)) // type mismatch, required: S, found: NonEmpty
  }
}

This is my failed attempt at implementing the function with the signature def assertNonNegative[S <: IntTree](t: S): S. Other than changing the signature to def assertNonNegative(t: IntTree): IntTree, I couldn't find a way to implement it.

Relevance of example:
In a video about subtyping and generics (4.4) in the course "Functional Programming Principles in Scala", Martin Odersky uses practically the same example (IntSet instead of IntTree) and says that this signature can be used to express that the function takes Empty to Empty and NonEmpty to NonEmpty. He says that the other signature is fine in most situations, but that the one with the upper-bounded S can be a more precise option if needed. However, he does not show an implementation of the function.

What am I missing here?


Solution

  • The method's right-hand side (pattern matching)

    t match {
      case Empty => Empty 
      case NonEmpty(elem, left, right) =>
        if (elem < 0) throw new Exception
        else NonEmpty(elem, assertNonNegative(left), assertNonNegative(right))
    }
    

    means checking at runtime whether t is the object Empty (the stable-identifier pattern compares with ==, and Empty is the only instance of class Empty$), in which case the first branch is chosen, and otherwise whether t is an instance of class NonEmpty, in which case the second branch is chosen.
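
    Spelled out by hand, that runtime dispatch looks roughly like the sketch below. It is an approximation written for illustration, not the compiler's actual output; the wrapper object RuntimeDispatch is a made-up name, and the monomorphic signature is used so that the snippet compiles.

```scala
object RuntimeDispatch {
  abstract class IntTree
  object Empty extends IntTree
  case class NonEmpty(elem: Int, left: IntTree, right: IntTree) extends IntTree

  // Hand-written approximation of what the `match` does at runtime.
  def assertNonNegative(t: IntTree): IntTree =
    if (Empty == t) Empty                          // case Empty: an equality test against the object
    else if (t.isInstanceOf[NonEmpty]) {           // case NonEmpty(...): a runtime class test
      val n = t.asInstanceOf[NonEmpty]
      if (n.elem < 0) throw new Exception
      else NonEmpty(n.elem, assertNonNegative(n.left), assertNonNegative(n.right))
    }
    else throw new MatchError(t)                   // no case matched
}
```

    Nothing in these checks mentions S: they only look at the runtime class of the value, which is why they cannot justify a result of type S.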

  • The signature

    def assertNonNegative[S <: IntTree](t: S): S
    

    means checking at compile time that, for every type S that is a subtype of IntTree, if the method accepts a parameter t of type S, then it returns a value of type S.

    The code doesn't compile because the method's definition doesn't correspond to its signature. The subclasses of IntTree here are NonEmpty and Empty (an object). If IntTree is not sealed, you can create subclasses other than Empty and NonEmpty; you can even create them dynamically at runtime. But let's suppose that IntTree is sealed and that Empty and NonEmpty are its only subclasses.

    The thing is that IntTree has many subtypes (classes and types are different things): IntTree, Empty.type, NonEmpty, Nothing, Null, Empty.type with NonEmpty, NonEmpty with SomeType, Empty.type with SomeType, IntTree with SomeType, T (type T <: IntTree), x.type (val x: IntTree = ???), etc. The condition (t: S): S must hold for every one of them.
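
    A quick way to convince yourself that these really are admissible choices for S is to instantiate a bounded type parameter with each of them. The sketch below assumes Scala 2 syntax; accepts and SubtypeWitnesses are hypothetical names introduced only for this check, and Serializable stands in for SomeType.

```scala
object SubtypeWitnesses {
  abstract class IntTree
  object Empty extends IntTree
  case class NonEmpty(elem: Int, left: IntTree, right: IntTree) extends IntTree

  // Hypothetical helper: a call compiles only if S satisfies the bound S <: IntTree.
  def accepts[S <: IntTree]: Boolean = true

  val x: IntTree = Empty

  // Every line below type-checks, so each of these types is a legal S.
  val all: List[Boolean] = List(
    accepts[IntTree],
    accepts[Empty.type],
    accepts[NonEmpty],
    accepts[Nothing],
    accepts[Null],
    accepts[Empty.type with NonEmpty],      // uninhabited, but still a subtype
    accepts[Empty.type with Serializable],
    accepts[x.type]                         // singleton type of a stable value
  )
}
```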

    Obviously it doesn't. For example, take t = Empty.asInstanceOf[Empty.type with Serializable]. Its static type is Empty.type with Serializable. At runtime it matches the object Empty, so the first branch is selected. But the compiler doesn't know that: how could it guarantee at compile time that both the Empty and the NonEmpty being returned have type Empty.type with Serializable?

    See also: Type mismatch on abstract type used in pattern matching

    One way to fix assertNonNegative is to write an honest monomorphic signature:

    def assertNonNegative(t: IntTree): IntTree = {
      t match {
        case Empty => Empty
        case NonEmpty(elem, left, right) =>
          if (elem < 0) throw new Exception
          else NonEmpty(elem, assertNonNegative(left), assertNonNegative(right))
      }
    }
    

    Another is to pretend that the polymorphic signature is correct and cast:

    def assertNonNegative[S <: IntTree](t: S): S = {
      (t match {
        case Empty => Empty
        case NonEmpty(elem, left, right) =>
          if (elem < 0) throw new Exception
          else NonEmpty(elem, assertNonNegative(left), assertNonNegative(right))
      }).asInstanceOf[S]
    }
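
    The cast compiles, but it does not make the signature true. The sketch below shows two ways the promise can be broken, assuming the non-sealed IntTree from the question; Tagged and UncheckedCastDemo are hypothetical names added only for this illustration.

```scala
object UncheckedCastDemo {
  abstract class IntTree
  object Empty extends IntTree
  case class NonEmpty(elem: Int, left: IntTree, right: IntTree) extends IntTree

  // Hypothetical extra subtype, introduced only for this illustration.
  class Tagged(e: Int, val tag: String) extends NonEmpty(e, Empty, Empty)

  def assertNonNegative[S <: IntTree](t: S): S =
    (t match {
      case Empty => Empty
      case NonEmpty(elem, left, right) =>
        if (elem < 0) throw new Exception
        else NonEmpty(elem, assertNonNegative(left), assertNonNegative(right))
    }).asInstanceOf[S] // S is erased, so this cast checks nothing useful at runtime

  // With S = x.type the signature promises the very same instance,
  // yet the method builds and returns a fresh node.
  val x: NonEmpty = NonEmpty(1, Empty, Empty)
  val y: x.type = assertNonNegative[x.type](x)
  val sameInstance: Boolean = y eq x

  // With S = Tagged the signature promises a Tagged, but a plain NonEmpty
  // comes back, so the cast inserted at the call site fails.
  def roundTrip(): Tagged = assertNonNegative(new Tagged(1, "root"))
}
```

    Here sameInstance is false even though y is typed as x.type, and calling roundTrip() throws a ClassCastException. For callers that only ever pass values typed as Empty.type or NonEmpty, the cast version behaves as intended.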
    

    A third is to use type tags:

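    import scala.reflect.runtime.universe._  // TypeTag, typeOf (needs scala-reflect on the classpath)

    def assertNonNegative[S <: IntTree : TypeTag](t: S): S = {
      t match {
        // =:= is the type-equivalence check for reflection Types; == can miss equivalent types
        case Empty if typeOf[S] =:= typeOf[Empty.type] => Empty.asInstanceOf[S]
        case NonEmpty(elem, left, right) if typeOf[S] =:= typeOf[NonEmpty] =>
          if (elem < 0) throw new Exception
          else NonEmpty(elem, assertNonNegative(left), assertNonNegative(right)).asInstanceOf[S]
        // every other S ends up here, including S = IntTree, which the recursive calls above infer
        case _ => ???
      }
    }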
    

    A fourth is to make the ADT more type-level:

    sealed trait IntTree
    object Empty extends IntTree
    case class NonEmpty[L <: IntTree, R <: IntTree](elem: Int, left: L, right: R) extends IntTree
    

    and define a type class:

    def assertNonNegative[S <: IntTree](t: S)(implicit ann: AssertNonNegative[S]): S = ann(t)
    
    trait AssertNonNegative[S <: IntTree] {
      def apply(t: S): S
    }
    object AssertNonNegative {
      implicit val empty: AssertNonNegative[Empty.type] = { case Empty => Empty }
      implicit def nonEmpty[L <: IntTree : AssertNonNegative, 
                            R <: IntTree : AssertNonNegative]: AssertNonNegative[NonEmpty[L, R]] = {
        case NonEmpty(elem, left, right) =>
          if (elem < 0) throw new Exception
          else NonEmpty(elem, assertNonNegative(left), assertNonNegative(right))
      }
    }
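
    To see what the type-class encoding buys, here is a self-contained usage sketch. It restates the definitions above with explicit anonymous classes and explicit implicit parameters (a stylistic substitution that avoids relying on SAM conversion); TypeLevelDemo, tree, checked and stillEmpty are hypothetical names.

```scala
object TypeLevelDemo {
  sealed trait IntTree
  object Empty extends IntTree
  case class NonEmpty[L <: IntTree, R <: IntTree](elem: Int, left: L, right: R) extends IntTree

  trait AssertNonNegative[S <: IntTree] {
    def apply(t: S): S
  }

  object AssertNonNegative {
    // Instance for the empty tree: nothing to check.
    implicit val empty: AssertNonNegative[Empty.type] =
      new AssertNonNegative[Empty.type] {
        def apply(t: Empty.type): Empty.type = Empty
      }

    // Instance for a node, built from the instances of both subtrees.
    implicit def nonEmpty[L <: IntTree, R <: IntTree](
        implicit l: AssertNonNegative[L], r: AssertNonNegative[R]
    ): AssertNonNegative[NonEmpty[L, R]] =
      new AssertNonNegative[NonEmpty[L, R]] {
        def apply(t: NonEmpty[L, R]): NonEmpty[L, R] =
          if (t.elem < 0) throw new Exception
          else NonEmpty(t.elem, l(t.left), r(t.right))
      }
  }

  def assertNonNegative[S <: IntTree](t: S)(implicit ann: AssertNonNegative[S]): S = ann(t)

  // The shape of the tree is part of its static type, and that type is preserved.
  val tree = NonEmpty(1, Empty, NonEmpty(2, Empty, Empty))
  val checked: NonEmpty[Empty.type, NonEmpty[Empty.type, Empty.type]] = assertNonNegative(tree)
  val stillEmpty: Empty.type = assertNonNegative(Empty)
}
```

    The price is that a value typed merely as IntTree can no longer be passed, because no AssertNonNegative[IntTree] instance exists; the shape of the tree has to be known at compile time.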
    

    Soundness of the type system comes at a price: the compiler sometimes rejects programs that could never go wrong at runtime. For example:

    val x: Int = if (true) 1 else "a"
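
    A small sketch of the same point: once the declared type is widened the expression is accepted, and the value it produces at runtime is indeed an Int. The names SoundnessDemo, y and z are made up for this illustration.

```scala
object SoundnessDemo {
  // Rejected by the compiler: the else-branch is a String, so the `if` is not an Int.
  // val x: Int = if (true) 1 else "a"

  // Accepted with a wider declared type; only the first branch can ever run.
  val y: Any = if (true) 1 else "a"

  // The cast succeeds at runtime because y really holds the integer 1.
  val z: Int = y.asInstanceOf[Int]
}
```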