Scala: Why is this function not tail recursive?


I have the following implementation of merge sort:

import scala.annotation.tailrec

object MergeSort {
  def sortBy[T]: ((T, T) => Int) => Seq[T] => Seq[T] = comparator => seqToSort => {
    @tailrec
    def merge(xs : Seq[T], ys : Seq[T], accum : Seq[T] = Seq()) : Seq[T] = (xs, ys) match {
      case (Seq(), _) => ys ++ accum
      case (_, Seq()) => xs ++ accum
      case (x::rx, y::ry) =>
        if(comparator(x, y) < 0)
          merge(xs, ry, y +: accum)
        else
          merge(rx, ys, x +: accum)
    }

    @tailrec
    // Problem with this function
    def step : Seq[Seq[T]] => Seq[T] = {
      case Seq(xs) => xs
      case xss =>
        val afterStep = xss.grouped(2).map({
          case Seq(xs) => xs
          case Seq(xs, ys) => merge(xs, ys)
        }).toSeq
        // Error here
        step(afterStep)
    }

    step(seqToSort.map(Seq(_)))
  }
}

It does not compile. The compiler reports that the recursive call in the step function is not in tail position, but it IS in tail position. Is there any way to fix it without a trampoline?


Solution

  • The reason is that step is a parameterless method that returns a function of type Seq[Seq[T]] => Seq[T]. The recursive call therefore does not call the method directly: it first calls step to obtain the function value and then invokes that function's apply method with the given argument. The final call is to apply, not to step, so it is not a tail call to step.

    To fix the error, declare step as a method that takes the sequence as a parameter:

    @tailrec
    def step(seq: Seq[Seq[T]]): Seq[T] = seq match {
      ...
    }
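
For illustration, here is a minimal sketch of how the whole object might look with that change applied. The object name MergeSortSketch is hypothetical, and the body of step is simply carried over from the question. Two details are assumptions that go beyond the answer: merge reverses its accumulator and uses the +: extractor so that the result comes out in ascending order for any Seq, and step gets an extra case for empty input so that it terminates.

```scala
import scala.annotation.tailrec

// Hypothetical name, used only for this sketch.
object MergeSortSketch {
  def sortBy[T]: ((T, T) => Int) => Seq[T] => Seq[T] = comparator => seqToSort => {
    // Assumption: accum holds the merged prefix in reverse order,
    // so it is reversed once when either input runs out.
    @tailrec
    def merge(xs: Seq[T], ys: Seq[T], accum: Seq[T] = Seq()): Seq[T] = (xs, ys) match {
      case (Seq(), _) => accum.reverse ++ ys
      case (_, Seq()) => accum.reverse ++ xs
      case (x +: rx, y +: ry) =>
        if (comparator(x, y) <= 0) merge(rx, ys, x +: accum)
        else merge(xs, ry, y +: accum)
    }

    // step is now a method with a parameter list, so step(afterStep) is a
    // direct self-call in tail position. In the original, the same expression
    // meant step.apply(afterStep), a call on a function value.
    @tailrec
    def step(xss: Seq[Seq[T]]): Seq[T] = xss match {
      case Seq()   => Seq() // assumption: empty input yields an empty result
      case Seq(xs) => xs
      case _ =>
        val afterStep = xss.grouped(2).map {
          case Seq(xs)     => xs
          case Seq(xs, ys) => merge(xs, ys)
        }.toSeq
        step(afterStep)
    }

    step(seqToSort.map(Seq(_)))
  }
}
```

Callers are unaffected by the change, because only the local helper changed shape: MergeSortSketch.sortBy[Int]((a, b) => Integer.compare(a, b))(Seq(3, 1, 2)) is still invoked the same way as before.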