
Enforce that all inputs of some type `T` to methods of a trait `F` have been previously produced by `F` itself


Suppose that I have some trait Foo, for example:

trait Foo {
  def bar: List[Int]
  def baz(i: Int): Unit
}

I want to enforce at compile time that all inputs passed to baz have been previously produced by bar. For example, if this is an implementation of Foo:

object A extends Foo {
  def bar = List(2, 3, 5, 7)
  def baz(i: Int): Unit = {
    if (bar contains i) println("ok")
    else println("not good")
  }
}

then I want to enforce that only single-digit primes (the values returned by bar) can be passed to baz. This obviously can't work as long as the input type of baz is known to be Int, because then I can create arbitrary integers that are not single-digit primes and pass them in:

val okValues: List[Int] = A.bar
A.baz(okValues(1)) // ok
A.baz(3)           // ok (but dangerous! `3` appeared out of nowhere!)
A.baz(42)          // not good

How can I enforce that only the values previously produced by bar can be passed to baz?


What doesn't work

Replacing Int by a type member T of Foo doesn't help, because the implementation A sets T to the concrete type Int, and since that is visible in A's type, A.baz still accepts any Int:

trait Foo {
  type T
  def bar: List[T]
  def baz(t: T): Unit
}

object A extends Foo {
  type T = Int
  def bar = List(2, 3, 5, 7)
  def baz(i: Int): Unit = {
    if (bar contains i) println("ok")
    else println("not good")
  }
}

A.baz(42) // still compiles; prints "not good" at runtime
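To make the reason concrete, here is a minimal sketch. It assumes, unlike the code above, that baz returns its message as a String instead of printing it, so the result can be checked; the object name TypeMemberLeak is hypothetical. Because `type T = Int` is part of the static type of A, the compiler can prove that A.T and Int are the same type, so any Int type-checks as an A.T:

```scala
trait Foo {
  type T
  def bar: List[T]
  // Assumption: returns the message instead of printing it.
  def baz(t: T): String
}

object A extends Foo {
  type T = Int
  def bar: List[T] = List(2, 3, 5, 7)
  def baz(i: T): String = if (bar.contains(i)) "ok" else "not good"
}

// Hypothetical helper object, only for this illustration.
object TypeMemberLeak {
  // Compiles: the alias `type T = Int` is visible through A's type,
  // so evidence that A.T is the same type as Int can be summoned.
  val evidence: A.T =:= Int = implicitly[A.T =:= Int]

  // Hence Ints that never came from A.bar are accepted as A.T values.
  val fromNowhere: String = A.baz(3)
  val invalid: String = A.baz(42)
}
```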

Solution

  • Here is one solution that replaces the concrete type Int by an abstract type member T, and then simply never exposes that T is implemented as Int:

    • replace the concrete type Int by an abstract type member T
    • move the methods bar and baz, together with the type T, into an inner trait YFoo
    • inside Foo, provide a method yFoo that returns a YFoo without exposing what T is.

    In code:

    trait Foo {
      trait YFoo {
        type T
        def bar: List[T]
        def baz(i: T): Unit
      }
      def yFoo: YFoo
    }
    
    object B extends Foo {
      def yFoo: YFoo = new YFoo {
        type T = Int
        def bar: List[Int] = List(2, 3, 5, 7)
        def baz(i: Int): Unit = {
          if (bar contains i) println("ok")
          else println("not good")
        }
      }
    }
    
    val a = B.yFoo
    val okValues: List[a.T] = a.bar
    a.baz(okValues(1)) // ok
    // a.baz(3)           // dangerous stuff doesn't compile!
    // found   : Int(3)
    // required: a.T
    
    // a.baz(42)          // invalid stuff also doesn't compile
    // found   : Int(42)
    // required: a.T
    

    Now none of the dangerous or invalid calls even compile. The only values that you can pass to a.baz are those from the "list of certified values" produced by a.bar.
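    A small sketch of the same solution that also shows the path-dependent aspect: every stable value obtained from B.yFoo carries its own abstract T, so values certified by one instance are rejected by another. As in the earlier variation, it assumes baz returns its message as a String rather than printing it; the object name TwoInstances is hypothetical:

```scala
trait Foo {
  trait YFoo {
    type T
    def bar: List[T]
    // Assumption: returns the message instead of printing it.
    def baz(i: T): String
  }
  def yFoo: YFoo
}

object B extends Foo {
  // Every call builds a fresh YFoo; callers only see the abstract T.
  def yFoo: YFoo = new YFoo {
    type T = Int
    def bar: List[Int] = List(2, 3, 5, 7)
    def baz(i: Int): String = if (bar.contains(i)) "ok" else "not good"
  }
}

// Hypothetical demo object.
object TwoInstances {
  val a = B.yFoo
  val b = B.yFoo

  // Values certified by a.bar are accepted by a.baz.
  val results: List[String] = a.bar.map(x => a.baz(x))

  // a.T and b.T are unrelated abstract types, so neither line compiles:
  // a.baz(3)            // an Int is not an a.T
  // a.baz(b.bar.head)   // a b.T is not an a.T
}
```

    Note that a and b are vals: a path-dependent type such as a.T can only be written for a stable path, which is why the result of the def yFoo is first stored in a val.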