Tags: scala, pattern-matching, traits, case-class

Pattern Matching Case Classes with Common Trait


I am writing some code that uses pattern matching. In testing I found a strange result:

object Example extends App {

  trait Human {
    def sing(): Unit
  }

  case class Son(name: String) extends Human {
    override def sing(): Unit = println("son " + name)
  }

  case class Daughter(name: String) extends Human {
    override def sing(): Unit = println("daughter " + name)
  }
  
  val jack = Son("jack")
  val sonia = Daughter("sonia")

  def f1(lst: List[Human]) = {
    lst match {
      case a: List[Son] => println("human is son")
      case b: List[Daughter] => println("human is daughter")
    }
  }

  f1(List(jack))
  f1(List(sonia))
}

Both calls print "human is son". Is there a way around this? It looks as if the match treats List[Son] and List[Daughter] as the same type. Is there a way to get it to distinguish between the two?


Solution

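  • It looks like you really need to refactor your design. Your match cannot work as written because of type erasure: the type argument of List is not available at runtime, so the pattern a: List[Son] only checks that the value is a List (the compiler warns that this type test is unchecked), and the first case matches every list. You shouldn't need to check the types of the elements in a list at runtime anyway: you can have an overridden method if you want dynamic dispatch, or you can use separate methods for List[Son]s and List[Daughter]s.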
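
    A minimal sketch of the dynamic-dispatch route, assuming you are free to add a method to the trait. The method name describe and the wrapper object DispatchExample are hypothetical, and the strings are returned instead of printed so the result is easy to check:

```scala
// Sketch: each subclass says what it is through an overridden method,
// so no runtime type test on the list is needed.
// `describe` and `DispatchExample` are hypothetical names for illustration.
object DispatchExample {
  sealed trait Human {
    def describe: String
  }

  final case class Son(name: String) extends Human {
    override def describe: String = "human is son"
  }

  final case class Daughter(name: String) extends Human {
    override def describe: String = "human is daughter"
  }

  // Dynamic dispatch selects the right `describe` for each element at runtime.
  def f1(lst: List[Human]): List[String] = lst.map(_.describe)

  def main(args: Array[String]): Unit = {
    // Prints "human is son", then "human is daughter".
    f1(List(Son("jack"), Daughter("sonia"))).foreach(println)
  }
}
```

    Because every element carries its own implementation, a list that mixes Sons and Daughters needs no special handling.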

    If you really want to make sure that all of the elements of the list are Sons or all are Daughters, you can use forall:

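    def f1(lst: List[Human]): Unit =
        // isInstanceOf is fine on the elements: Son and Daughter take no type arguments
        if (lst.forall(_.isInstanceOf[Son])) println("human is son") // forall is also true for an empty list
        else if (lst.forall(_.isInstanceOf[Daughter])) println("human is daughter")
        // a list mixing Sons and Daughters falls through and prints nothing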
    

    This isn't great, though. What if there's a list with Sons and Daughters, or maybe some third type altogether?
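
    If you do stay with forall, one way to surface those cases is to handle them explicitly. This is a sketch under the assumption that returning a description string is acceptable; the wrapper object ForallExample and the "empty list" and "mixed list" messages are hypothetical:

```scala
// Sketch of the forall approach with explicit fallbacks.
// `ForallExample` and the "empty list" / "mixed list" messages are hypothetical.
object ForallExample {
  sealed trait Human
  final case class Son(name: String) extends Human
  final case class Daughter(name: String) extends Human

  def f1(lst: List[Human]): String =
    if (lst.isEmpty) "empty list" // forall is vacuously true on Nil, so check this first
    else if (lst.forall(_.isInstanceOf[Son])) "human is son"
    else if (lst.forall(_.isInstanceOf[Daughter])) "human is daughter"
    else "mixed list" // e.g. Sons and Daughters together

  def main(args: Array[String]): Unit = {
    println(f1(List(Son("jack"))))                    // human is son
    println(f1(List(Daughter("sonia"))))              // human is daughter
    println(f1(List(Son("jack"), Daughter("sonia")))) // mixed list
  }
}
```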

    I'd recommend two different methods: one for Sons and one for Daughters. I would also make your Human trait sealed, so that no implementations outside this source file have to be dealt with.

    def f1(lst: List[Daughter]): Unit = println("human is daughter")
    // DummyImplicit works around type erasure: without the extra parameter list,
    // both overloads would erase to f1(List) and clash as a double definition
    def f1(lst: List[Son])(implicit d: DummyImplicit): Unit = println("human is son")
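
    A self-contained sketch of the overloading approach, with the sealed trait and String results instead of println so the calls are easy to check (the wrapper object OverloadExample is hypothetical). The overload is chosen at compile time from the static type of the argument:

```scala
// Sketch of overloading on List[Daughter] vs List[Son].
// `OverloadExample` is a hypothetical wrapper for illustration.
object OverloadExample {
  sealed trait Human
  final case class Son(name: String) extends Human
  final case class Daughter(name: String) extends Human

  def f1(lst: List[Daughter]): String = "human is daughter"

  // The extra implicit parameter list keeps the erased signatures distinct.
  def f1(lst: List[Son])(implicit d: DummyImplicit): String = "human is son"

  def main(args: Array[String]): Unit = {
    val jack = Son("jack")
    val sonia = Daughter("sonia")
    println(f1(List(jack)))  // human is son
    println(f1(List(sonia))) // human is daughter
    // f1(List(jack, sonia)) would not compile: neither overload accepts a mixed list.
  }
}
```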
    

    You could also use type classes, although it doesn't seem worth it here. The compiler then picks the matching instance from the list's static element type:

    def f1[A <: Human](lst: List[A])(implicit idr: Identifier[A]): String =
      idr.identify(lst)
    
    sealed trait Identifier[A <: Human] {
      def identify(lst: List[A]): String
    }
    // instances in the companion object are found without an import
    object Identifier {
      implicit val sonIdentifier: Identifier[Son] = _ => "human is son"
      implicit val daughterIdentifier: Identifier[Daughter] = _ => "human is daughter"
    }