Tags: scala, recursion, functional-programming, tail-recursion

Why use a helper function inside a recursive function?


In the book Functional Programming in Scala, while explaining why functional programming often uses recursion instead of imperative iteration, the authors demonstrate recursion with a factorial function that uses a helper function called "go" or "loop", and state that this is standard practice in functional Scala programming:

```scala
import scala.annotation.tailrec

def factorial(n: Int): Int = {
  @tailrec def go(n: Int, acc: Int): Int = {
    if (n <= 0) acc
    else go(n - 1, n * acc)
  }
  go(n, 1)
}
```

But one could just as easily, and arguably more concisely, define it without a helper function:

```scala
def factorial(n: Int): Int = {
  if (n <= 0) 1
  else n * factorial(n - 1)
}
```

My understanding is that recursion accumulates values without mutation by leveraging the stack: each stack frame "passes" its return value back to the previous one. Here, the authors appear to be using an explicit accumulator parameter for a similar purpose.

Is there an advantage in using helper functions to accumulate values like this, or are they using this example to show how recursion relates to imperative iteration by explicitly passing state to the helper function?


Solution

  • Tail recursion enables an important optimization: the compiler can turn the recursive call into a loop, which avoids a "stack overflow" when the recursion is very deep and usually makes the function faster.

    Recursion without tail recursion

    def factorial(n: Int): Int = {
      if (n <= 0) 1
      else n * factorial(n - 1)
    }

    What happens when I want to calculate factorial(10)? First, I calculate factorial(9); then, I multiply the result by 10. This means that while I am calculating factorial(9), I need to keep a note somewhere: "Remember, when you're done with factorial(9), you still need to multiply by 10!".

    And then, in order to calculate factorial(9), I must first calculate factorial(8), then multiply the result by 9. So I write another little note: "Remember to multiply by 9 when you have the result of factorial(8)."

    This goes on; finally I arrive at factorial(0), which is 1. By this time I have ten little notes that say "Remember to multiply by 1 when you're done with factorial(0)", "Remember to multiply by 2 when you're done with factorial(1)", etc.

    Each note is a stack frame, and together they are called "the stack" because they are quite literally stacked on top of each other. If the stack gets too big, the program crashes with a "stack overflow" (on the JVM, a StackOverflowError).

    Tail recursion

    import scala.annotation.tailrec

    def factorial(n: Int): Int = {
      @tailrec def go(n: Int, acc: Int): Int = {
        if (n <= 0) acc
        else go(n - 1, n * acc)
      }
      go(n, 1)
    }

    The function go in this program is different. To calculate go(10, 1) you need to calculate go(9, 10); but once you have the result of go(9, 10), you don't need to do anything else with it! You can return it directly. So there is no need for a little note saying "remember to multiply the result by 10 after the recursive call".

    The compiler exploits this: instead of stacking the call to go(9, 10) on top of the call to go(10, 1), it replaces the call to go(10, 1) with the call to go(9, 10). Then it replaces the call to go(9, 10) with a call to go(8, 90), and so on. So the stack never grows during the recursion. This is called tail recursion because the recursive call is the last thing that happens in the execution of the function (in particular, the multiplication by 10 happens while evaluating the arguments, before the call).

    This is also why the helper function is used: the running product has to travel in an extra accumulator parameter, and nesting go inside factorial keeps that parameter out of factorial's public signature. The @tailrec annotation does not itself enable the optimization; it makes the compiler report an error if go is not tail-recursive.
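To see the difference in practice, here is a minimal sketch comparing the two shapes of recursion at a depth where the stack matters. It swaps factorial for a sum (the names StackDepthDemo, sumTo and sumToTail are hypothetical, for illustration only), because factorial of a large n would overflow Int long before the stack runs out. With default JVM stack settings the non-tail version is expected to throw StackOverflowError at this depth, while the accumulator version runs in constant stack space; the exact depth at which the first one fails depends on the JVM and its -Xss setting.

```scala
import scala.annotation.tailrec

object StackDepthDemo {
  // Not tail-recursive: the addition happens after the recursive call returns,
  // so every pending "n + ..." occupies its own stack frame.
  def sumTo(n: Long): Long =
    if (n <= 0) 0L
    else n + sumTo(n - 1)

  // Tail-recursive: the running total travels in the accumulator, so the
  // compiler turns the self-call into a jump and the stack does not grow.
  def sumToTail(n: Long): Long = {
    @tailrec def go(n: Long, acc: Long): Long =
      if (n <= 0) acc
      else go(n - 1, acc + n)
    go(n, 0L)
  }

  def main(args: Array[String]): Unit = {
    val n = 1000000L
    println(sumToTail(n)) // 500000500000
    // Expected to overflow with default stack settings; caught here for the demo.
    try println(sumTo(n))
    catch { case _: StackOverflowError => println("sumTo: StackOverflowError") }
  }
}
```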
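The accumulator version also relates directly to imperative iteration: once the tail call is eliminated, go behaves like an ordinary loop. The sketch below writes that loop by hand (the object name FactorialLoop is hypothetical). It is not literally the code scalac emits, but it behaves roughly the same way: go's parameters become mutable variables, and each recursive call becomes a reassignment followed by another pass through the loop.

```scala
object FactorialLoop {
  // Hand-written equivalent of go(n, acc) after tail-call elimination:
  // the parameters become local variables and every recursive call
  // go(n - 1, n * acc) becomes an update of those variables.
  def factorial(start: Int): Int = {
    var n = start   // plays the role of go's n parameter
    var acc = 1     // plays the role of go's acc parameter, initialised as in go(n, 1)
    while (n > 0) { // go returns acc as soon as n <= 0
      acc = n * acc // the new acc uses the old n, as in go(n - 1, n * acc)
      n = n - 1
    }
    acc
  }

  def main(args: Array[String]): Unit =
    println(factorial(10)) // 3628800
}
```

Like the original, this uses Int, so results are only correct up to factorial(12); factorial(13) already overflows.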