Tags: scala, types, compilation, type-inference, type-systems

Scala Puzzle: enforcing that two function arguments are of the same type AND both are a subtype of a given class


How can I enforce at compile time that trickyMethod's two arguments are of the same type, while also requiring that this type is a subtype of Fruit?

So in other words, tricky.trickyMethod(new Banana,new Apple) should not compile.

I am sure there must be a simple solution, but I have just spent an hour searching for the answer and still have no idea :(

I tried implicit evidence with <:<, but I could not get it to work.

class Fruit
class Apple extends Fruit
class Banana extends Fruit

class TrickyClass[T<:Fruit]{
  def trickyMethod(p1:T,p2:T)= println("I am tricky to solve!")
}

object TypeInferenceQuestion extends App{
  val tricky=new TrickyClass[Fruit]()
  tricky.trickyMethod(new Apple,new Apple) //this should be OK
  tricky.trickyMethod(new Banana,new Banana) //this should be OK
  tricky.trickyMethod(new Banana,new Apple) //this should NOT compile

}
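Regarding the <:< attempt: a minimal sketch (not from the original post; the object and method names are hypothetical) of why subtype evidence alone does not express "same type". `A <:< B` only proves that A is a subtype of B, so it is one-directional, whereas `A =:= B` requires the two types to be identical.

```scala
class Fruit
class Apple extends Fruit
class Banana extends Fruit

// Hypothetical helpers contrasting the two kinds of implicit type evidence.
object EvidenceDemo {
  // A <:< B only proves that A is a subtype of B (one direction),
  // so (new Apple, new Fruit) is accepted even though the types differ.
  def withSubtype[A, B](a: A, b: B)(implicit ev: A <:< B): String = "accepted"

  // A =:= B proves that A and B are the same type (both directions).
  def withEquality[A, B](a: A, b: B)(implicit ev: A =:= B): String = "accepted"
}

object EvidenceMain extends App {
  import EvidenceDemo._
  println(withSubtype(new Apple, new Fruit))   // compiles: Apple <:< Fruit exists
  println(withEquality(new Apple, new Apple))  // compiles: Apple =:= Apple exists
  // withEquality(new Apple, new Fruit)        // does not compile: no Apple =:= Fruit
  // withSubtype(new Banana, new Apple)        // does not compile: Banana is not an Apple
}
```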

EDIT:

Thank you for the answers !

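Follow-up (more general) question:

This second example generalizes the first: the class parameter T no longer has an upper bound, and the constraint S <: T moves onto the method's own type parameter.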

class Fruit

class Apple extends Fruit
class Banana extends Fruit

class TrickyClass[T]{
  def trickyMethod[S<:T](p1:S,p2:S)= println("I am tricky to solve!")
}

object TypeInferenceQuestion extends App{
  val tricky=new TrickyClass[Fruit]()
  tricky.trickyMethod(new Apple,new Apple) //this should be OK
  tricky.trickyMethod(new Banana,new Banana) //this should be OK
  tricky.trickyMethod(new Banana,new Apple) //this should NOT compile

}
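A likely reason this version still accepts (new Banana, new Apple): with a single type parameter S, the compiler infers S as the least upper bound of the two argument types, which is Fruit, and Fruit satisfies S <: T. The sketch below makes that visible; it assumes a hypothetical inferredType method that requests a ClassTag[S] purely to report the inferred type at runtime.

```scala
import scala.reflect.ClassTag

class Fruit
class Apple extends Fruit
class Banana extends Fruit

// Hypothetical variant of the follow-up's TrickyClass: the ClassTag[S]
// is requested only so we can observe which type the compiler chose for S.
class TrickyClass[T] {
  def inferredType[S <: T](p1: S, p2: S)(implicit tag: ClassTag[S]): Class[_] =
    tag.runtimeClass
}

object InferenceDemo extends App {
  val tricky = new TrickyClass[Fruit]()
  // Both arguments are Apples, so S is inferred as Apple.
  println(tricky.inferredType(new Apple, new Apple).getName)
  // The arguments differ, so S widens to their least upper bound, Fruit,
  // which still satisfies S <: Fruit: this is why the call compiles.
  println(tricky.inferredType(new Banana, new Apple).getName)
}
```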

Solution

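  • You could give each argument its own type parameter and ask for implicit evidence that the two inferred types are equal: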

    class Tricky[T] {
      def trickyMethod[S1<:T,S2<:T](s1:S1,s2:S2)(implicit ev: S1=:=S2) = println()
    }
    
    
    scala> val t = new Tricky[Seq[Int]]
    t: Tricky[Seq[Int]] = Tricky@2e585191
    
    scala> t.trickyMethod(List(1),List(1))
    //OK
    
    scala> t.trickyMethod(List(1),Seq(1))
    <console>:10: error: Cannot prove that List[Int] =:= Seq[Int].
                  t.trickyMethod(List(1),Seq(1))
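For reference, here is the answer's technique applied to the question's own classes, as a sketch (the object name TypeInferenceAnswer is hypothetical, and the method returns a String instead of printing so it is easy to check). Because S1 and S2 are separate type parameters, each is inferred from its own argument, and the implicit S1 =:= S2 evidence can only be found when they match.

```scala
class Fruit
class Apple extends Fruit
class Banana extends Fruit

// S1 and S2 are inferred independently from each argument; both must be <: T,
// and the implicit S1 =:= S2 evidence only exists when they are the same type.
class TrickyClass[T] {
  def trickyMethod[S1 <: T, S2 <: T](p1: S1, p2: S2)(implicit ev: S1 =:= S2): String =
    "I am tricky to solve!"
}

object TypeInferenceAnswer extends App {
  val tricky = new TrickyClass[Fruit]()
  println(tricky.trickyMethod(new Apple, new Apple))   // OK: Apple =:= Apple
  println(tricky.trickyMethod(new Banana, new Banana)) // OK: Banana =:= Banana
  // tricky.trickyMethod(new Banana, new Apple)        // does not compile: Cannot prove that Banana =:= Apple.
}
```

One caveat: the check applies to the inferred static types, so a call such as `tricky.trickyMethod[Fruit, Fruit](new Banana, new Apple)`, or one passing two values statically typed as Fruit, still compiles, because Fruit =:= Fruit holds.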