Tags: scala, oop, inheritance, multiple-inheritance, case-class

Can Scala case classes match in both an inherited and non-inherited function?


Working through the Essential Scala book, I find myself playing with the linked list example. I have wound up with a situation where I have a superclass, for lists of anything, and a subclass for lists of Ints. I can ask for the length of either type of list, and the sum of a list of ints. However, I am having some trouble modelling this using case classes. The natural inheritance structure is as follows:

         List
     ↙    ↓    ↘
Pair   IntList   End
 ↓    ↙       ↘   ↓
IntPair       IntEnd

In List, I would like to be able to match on Pair vs. End, which I can do by making Pair and End case classes. I use this in the length function like this:

  final def length(counter: Int = 0): Int =
    this match {
      case Pair(_, tail: List[A]) => tail.length(counter + 1)
      case End() => counter
    }

However, I would also like to be able to match on IntPair vs. IntEnd in IntList. The natural implementation in IntList looks like this:

  final def sum(counter: Int = 0): Int =
    this match {
      case IntPair(head, tail) => tail.sum(head + counter)
      case IntEnd() => counter
    }

However, because Pair is already a case class, IntPair cannot be one too: Scala prohibits case-to-case inheritance. But if IntPair is not a case class, then we can't match on it in sum. If we don't inherit, then length stops working because its match doesn't know about IntPair/IntEnd.

I've made a GitHub gist with the full Scala worksheet, in case it helps. It can be run with scala -nc inheritance.sc.

What approach should I take instead of this one?


Solution

  • This question is a good illustration of why traditional subtype polymorphism (usually modelled through inheritance in OOP languages) is not enough to naturally implement some use cases.

    So let's look at some alternatives. Here I will cover implicit evidence, extension methods and typeclasses (which are like a more powerful and flexible combination of the previous two).

    First, let's write a common definition of a List.

    sealed trait MyList[+A] extends Product with Serializable {
      final def :!:[B >: A](b: B): MyList[B] = new :!:(b, this)
    
      final def length: Int = {
        @annotation.tailrec
        def loop(remaining: MyList[A], acc: Int): Int =
          remaining match {
            case _ :!: tail => loop(remaining = tail, acc + 1)
            case MyNil => acc
          }
        loop(remaining = this, acc = 0)
      }
    }
    final case class :!:[+A](head: A, tail: MyList[A]) extends MyList[A]
    final case object MyNil extends MyList[Nothing]
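
A quick usage sketch of the definition above, repeated so the snippet runs on its own (the value names ints and strings are illustrative). It shows that length works for a list of any element type, and that the right-associative :!: builds lists from right to left.

```scala
// Same MyList definition as above, repeated so this snippet is self-contained.
sealed trait MyList[+A] extends Product with Serializable {
  // Right-associative cons: `1 :!: MyNil` is MyNil.:!:(1).
  final def :!:[B >: A](b: B): MyList[B] = new :!:(b, this)

  final def length: Int = {
    @annotation.tailrec
    def loop(remaining: MyList[A], acc: Int): Int =
      remaining match {
        case _ :!: tail => loop(tail, acc + 1)
        case MyNil      => acc
      }
    loop(this, 0)
  }
}
final case class :!:[+A](head: A, tail: MyList[A]) extends MyList[A]
case object MyNil extends MyList[Nothing]

// Because the operator ends in ':', this reads as 1 :!: (2 :!: (3 :!: MyNil)).
val ints: MyList[Int] = 1 :!: 2 :!: 3 :!: MyNil
val strings: MyList[String] = "a" :!: "b" :!: MyNil

println(ints.length)    // 3
println(strings.length) // 2
println(MyNil.length)   // 0
```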
    

    Implicit evidence.

    The idea is simple: a List provides a sum method as long as its elements are Ints.
    That is the intuition, and we can encode it directly:

    sealed trait MyList[+A] extends Product with Serializable {
      // ...
    
      // B is required because of variance: A is covariant, but =:= is invariant in its first type parameter.
      final def sum[B >: A](implicit ev: B =:= Int): Int = {
        @annotation.tailrec
        def loop(remaining: MyList[B], acc: Int): Int =
          remaining match {
            case i :!: tail => loop(remaining = tail, acc + ev(i)) // ev converts each B to an Int
            case MyNil => acc
          }
        loop(remaining = this, acc = 0)
      }
    }
    

    This encoding allows us to call sum on any List as long as the compiler can prove that its elements are Ints; if it cannot, the call does not compile.
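
To see the evidence at work, here is a self-contained sketch that keeps only the cons operator and the sum method from above; the value names are illustrative. The commented-out line is expected to be rejected at compile time, because there is no B =:= Int evidence for a list of Strings.

```scala
sealed trait MyList[+A] extends Product with Serializable {
  final def :!:[B >: A](b: B): MyList[B] = new :!:(b, this)

  // B is needed because A is covariant, while =:= is invariant.
  final def sum[B >: A](implicit ev: B =:= Int): Int = {
    @annotation.tailrec
    def loop(remaining: MyList[B], acc: Int): Int =
      remaining match {
        case i :!: tail => loop(tail, acc + ev(i)) // ev turns a B into an Int
        case MyNil      => acc
      }
    loop(this, 0)
  }
}
final case class :!:[+A](head: A, tail: MyList[A]) extends MyList[A]
case object MyNil extends MyList[Nothing]

val ints: MyList[Int] = 1 :!: 2 :!: 3 :!: MyNil
val total: Int = ints.sum // B is inferred as Int, so the evidence is found

// val oops = ("a" :!: MyNil).sum // does not compile: no String =:= Int evidence
```

The method is defined for every MyList, but only calls that can supply the evidence type-check.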

    Extension methods.

    We can achieve exactly the same behaviour as with implicit evidence by adding sum as an extension method for Lists of Ints only.
    This can be helpful if you do not control the datatype, or if you do control it but have many methods that require extra conditions: you can split that code into multiple files and even let users opt in to those extensions.

    implicit class IntListOps(private val intList: MyList[Int]) extends AnyVal {
      def sum: Int = {
        @annotation.tailrec
        def loop(remaining: MyList[Int], acc: Int): Int =
          remaining match {
            case i :!: tail => loop(remaining = tail, acc + i)
            case MyNil => acc
          }
        loop(remaining = intList, acc = 0)
      }
    }
    

    Just as before, this encoding allows us to call the sum method on any List as long as the compiler can prove that the elements of the List are Ints (and that the extension method is in scope, usually through an import).
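
A minimal sketch of the opt-in import, assuming the extension lives in a hypothetical IntListSyntax object. Unlike the answer's version, IntListOps here does not extend AnyVal; that is only so the snippet also compiles when wrapped as a script, where value classes may not be allowed. The behaviour is the same, just without the allocation optimisation.

```scala
sealed trait MyList[+A] extends Product with Serializable {
  final def :!:[B >: A](b: B): MyList[B] = new :!:(b, this)
}
final case class :!:[+A](head: A, tail: MyList[A]) extends MyList[A]
case object MyNil extends MyList[Nothing]

// Hypothetical container object so that callers have to import the extension.
object IntListSyntax {
  // Only a MyList[Int] can be wrapped, so only Int lists gain `sum`.
  implicit class IntListOps(private val intList: MyList[Int]) {
    def sum: Int = {
      @annotation.tailrec
      def loop(remaining: MyList[Int], acc: Int): Int =
        remaining match {
          case i :!: tail => loop(tail, acc + i)
          case MyNil      => acc
        }
      loop(intList, 0)
    }
  }
}

import IntListSyntax._ // opt in to the extension

val ints: MyList[Int] = 10 :!: 20 :!: MyNil
val total: Int = ints.sum
```

Without the import, ints.sum would fail to compile because MyList itself has no sum member.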

    Typeclasses

    Finally, typeclasses: one of the most powerful patterns of functional programming languages.

    In this case, we need to add a lot of "boilerplate" code, and we still use extension methods. However, note that this allows us to extend the sum method to work on any type that is summable, like Double or BigInt.

    trait Summable[A] {
      def sum(a1: A, a2: A): A
      def zero: A
    }
    
    object Summable {
      implicit final val IntSummable: Summable[Int] =
        new Summable[Int] {
          override def sum(i1: Int, i2: Int): Int = i1 + i2
          override val zero: Int = 0
        }
      
      object syntax {
        implicit class SummableOps[A](private val a1: A) extends AnyVal {
          @inline final def |+|(a2: A)(implicit ev: Summable[A]): A = ev.sum(a1, a2)
        }
        
        implicit class SummableListOps[A](private val list: MyList[A]) extends AnyVal {
          @inline final def sum(implicit ev: Summable[A]): A = {
            @annotation.tailrec
            def loop(remaining: MyList[A], acc: A): A =
              remaining match {
                case i :!: tail => loop(remaining = tail, acc |+| i)
                case MyNil => acc
              }
            loop(remaining = list, acc = ev.zero)
          }
        }
      }
    }
    

    Again, the compiler will allow us to call sum on a List as long as it can prove that its elements are summable. And, since the implicit scope is quite flexible, you can easily make custom types Summable.
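
As an illustration of that flexibility, here is a sketch with a hypothetical Money type whose Summable instance lives in its own companion object. The companion is part of the implicit scope of Summable[Money], so no extra import is needed for the instance. For brevity, the list syntax calls ev.sum directly instead of |+|, and AnyVal is dropped so the snippet also runs as a script.

```scala
sealed trait MyList[+A] extends Product with Serializable {
  final def :!:[B >: A](b: B): MyList[B] = new :!:(b, this)
}
final case class :!:[+A](head: A, tail: MyList[A]) extends MyList[A]
case object MyNil extends MyList[Nothing]

trait Summable[A] {
  def sum(a1: A, a2: A): A
  def zero: A
}

object Summable {
  implicit val IntSummable: Summable[Int] = new Summable[Int] {
    override def sum(i1: Int, i2: Int): Int = i1 + i2
    override val zero: Int = 0
  }

  object syntax {
    implicit class SummableListOps[A](private val list: MyList[A]) {
      def sum(implicit ev: Summable[A]): A = {
        @annotation.tailrec
        def loop(remaining: MyList[A], acc: A): A =
          remaining match {
            case a :!: tail => loop(tail, ev.sum(acc, a))
            case MyNil      => acc
          }
        loop(list, ev.zero)
      }
    }
  }
}

// Hypothetical custom type, amounts stored in cents.
final case class Money(cents: Long)
object Money {
  // Found through the implicit scope of Summable[Money] (Money's companion).
  implicit val MoneySummable: Summable[Money] = new Summable[Money] {
    override def sum(m1: Money, m2: Money): Money = Money(m1.cents + m2.cents)
    override val zero: Money = Money(0L)
  }
}

import Summable.syntax._

val intTotal: Int = (1 :!: 2 :!: 3 :!: MyNil).sum
val moneyTotal: Money = (Money(150L) :!: Money(250L) :!: MyNil).sum
```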

    We could generalize this even further: what if we provided sum on any type that can be iterated over?
    This is exactly what Cats does. Among many other useful things, it provides two abstractions, Monoid (which is the mathematical name for Summable) and Foldable (which represents the ability to iterate over something), which combined give us the combineAll method (and many other powerful abstractions like foldMap) on any "collection" of "combinable" elements.
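
The following hand-rolled sketch shows the idea without depending on Cats. It is not the Cats API: MyFoldable and sumAll are hypothetical names, and MyFoldable only offers foldLeft. Combining a "foldable" container with "summable" elements is enough to write one generic sum.

```scala
sealed trait MyList[+A] extends Product with Serializable {
  final def :!:[B >: A](b: B): MyList[B] = new :!:(b, this)
}
final case class :!:[+A](head: A, tail: MyList[A]) extends MyList[A]
case object MyNil extends MyList[Nothing]

trait Summable[A] {
  def sum(a1: A, a2: A): A
  def zero: A
}
object Summable {
  implicit val IntSummable: Summable[Int] = new Summable[Int] {
    def sum(i1: Int, i2: Int): Int = i1 + i2
    val zero: Int = 0
  }
}

// Hypothetical abstraction over containers that can be folded from the left.
trait MyFoldable[F[_]] {
  def foldLeft[A, B](fa: F[A], b: B)(f: (B, A) => B): B
}
object MyFoldable {
  implicit val MyListFoldable: MyFoldable[MyList] = new MyFoldable[MyList] {
    def foldLeft[A, B](fa: MyList[A], b: B)(f: (B, A) => B): B = {
      @annotation.tailrec
      def loop(remaining: MyList[A], acc: B): B =
        remaining match {
          case a :!: tail => loop(tail, f(acc, a))
          case MyNil      => acc
        }
      loop(fa, b)
    }
  }
  implicit val ListFoldable: MyFoldable[List] = new MyFoldable[List] {
    def foldLeft[A, B](fa: List[A], b: B)(f: (B, A) => B): B = fa.foldLeft(b)(f)
  }
  implicit val OptionFoldable: MyFoldable[Option] = new MyFoldable[Option] {
    def foldLeft[A, B](fa: Option[A], b: B)(f: (B, A) => B): B = fa.foldLeft(b)(f)
  }
}

// One sum for any foldable container F of summable elements A.
def sumAll[F[_], A](fa: F[A])(implicit F: MyFoldable[F], S: Summable[A]): A =
  F.foldLeft(fa, S.zero)(S.sum)

val ints: MyList[Int] = 1 :!: 2 :!: 3 :!: MyNil
val fromMyList: Int = sumAll(ints)
val fromList: Int = sumAll(List(4, 5))
val fromEmptyOption: Int = sumAll(Option.empty[Int])
```

Adding a new container only needs a MyFoldable instance, and adding a new element type only needs a Summable instance; neither requires touching the other.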


    While each of those techniques is powerful on its own, the real power of Scala is that it allows you to mix them to get the API you desire.
    For example, the standard List provides its sum method using implicit evidence of a typeclass: an implicit Numeric instance for the element type.
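
A short sketch of that mix in the standard library. In Scala 2.13, sum is declared on IterableOnceOps as sum[B >: A](implicit num: Numeric[B]): B. The average helper below is hypothetical; it only shows that generic code can forward its own Numeric evidence to List#sum.

```scala
// Instances of Numeric exist for Int, Double and BigInt, so each call type-checks.
val intSum: Int = List(1, 2, 3).sum
val doubleSum: Double = List(1.5, 2.5).sum
val bigSum: BigInt = List(BigInt(10), BigInt(20)).sum

// Hypothetical helper: requires Numeric[A] and passes it on to sum.
def average[A](xs: List[A])(implicit num: Numeric[A]): Double =
  if (xs.isEmpty) 0.0 else num.toDouble(xs.sum) / xs.size

val avg: Double = average(List(1, 2, 3, 4))
```

Calling List("a", "b").sum, by contrast, does not compile, because there is no Numeric[String].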

    There are also other polymorphism techniques, like duck typing (structural types), which I personally do not like, but which are helpful to know about.
    In case you are interested, I once wrote about some of those techniques and their differences in Scala.
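
For reference, a minimal structural-typing sketch, assuming Scala 2.13 (in Scala 3 the enabling import is scala.reflect.Selectable.reflectiveSelectable instead). Rope, Interval, HasLength and twiceLength are hypothetical names. The two classes share no supertype; they only both have a length member, and the call is dispatched through Java reflection at runtime.

```scala
import scala.language.reflectiveCalls

// Two unrelated, hypothetical classes that merely share a `length: Int` member.
final class Rope(val pieces: List[String]) {
  def length: Int = pieces.map(_.length).sum
}
final case class Interval(start: Int, end: Int) {
  def length: Int = end - start
}

// Structural type ("duck typing"): anything with a `length: Int` member conforms.
type HasLength = { def length: Int }

// x.length is invoked reflectively, which is slower than a normal method call.
def twiceLength(x: HasLength): Int = 2 * x.length

val ropeTwice: Int = twiceLength(new Rope(List("ab", "cde"))) // 10
val intervalTwice: Int = twiceLength(Interval(3, 7))          // 8
```

The reflection cost and the lack of a named abstraction are the usual arguments for preferring typeclasses over structural types.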


    You can see the full code here