Tags: algorithm, haskell, recursion, combinations, tail-call-optimization

How can I use tail call optimisation on this combination function?


My exercise asks me to implement a recursive version of C(n, k).

Here is my solution:

module LE1.Recursao.Combinacao where

combina :: Integral a => a -> a -> a
combina n k | k == 0         = 1
            | n == k         = 1
combina n k | k > 0 && n > k = (combina (n - 1) k) + (combina (n - 1) (k - 1))
combina _ _                  = 0

Now I want to write a tail-recursive version of this function, so that I don't get stack overflows for large numbers and can also compute the combination faster.

I'm new to tail call optimisation, but I have done it in Elixir for the Fibonacci sequence:

defmodule Fibonacci do
  def run(n) when n < 0, do: :error
  def run(n), do: run n, 1, 0
  def run(0, _, result), do: result
  def run(n, next, result), do: run n - 1, next + result, next
end
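
For comparison, here is a sketch of the same accumulator-passing pattern in Haskell. The names `fib` and `go` are illustrative, not from the exercise. The bang patterns (a GHC extension) are an added assumption: because Haskell is lazy, a tail call alone does not stop unevaluated additions from piling up in the accumulators, so they are forced explicitly on every step.

```haskell
{-# LANGUAGE BangPatterns #-}
module Main where

-- Tail-recursive Fibonacci mirroring the Elixir run/3 clauses:
-- `next` holds the following term and `result` the current one.
fib :: Int -> Integer
fib n
  | n < 0     = error "fib: negative argument"
  | otherwise = go n 1 0
  where
    -- The bang patterns force both accumulators on every step, so no chain
    -- of unevaluated (+) thunks builds up behind the tail call.
    go :: Int -> Integer -> Integer -> Integer
    go 0 _     !result = result
    go k !next !result = go (k - 1) (next + result) next

main :: IO ()
main = print (map fib [0 .. 10])  -- [0,1,1,2,3,5,8,13,21,34,55]
```

The shape carries over to C(n, k): keep the partial result in an argument and make the recursive call the last thing the function does.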

I understand this Fibonacci code and I think that the combination algorithm isn't too different, but I don't know how to start.


Solution

  • Daniel Wagner writes:

    combina n k = product [n-k+1 .. n] `div` product [1 .. k]
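
    As a complete program (a sketch; the type signature and `main` are additions), the formula looks like this. Printing the digit counts for n = 2000, k = 1000 shows how much larger the numerator is than the result.

    ```haskell
    module Main where

    -- Daniel Wagner's closed form, C(n, k) = ((n-k+1) * ... * n) / k!,
    -- as a complete program. Assumes 0 <= k <= n. The division is exact,
    -- but both products are computed in full before anything is divided.
    combina :: Integer -> Integer -> Integer
    combina n k = product [n-k+1 .. n] `div` product [1 .. k]

    main :: IO ()
    main = do
      print (combina 5 2)   -- 10
      print (combina 52 5)  -- 2598960
      -- Decimal digits in the numerator vs. in the result, n = 2000, k = 1000:
      print (length (show (product [1001 .. 2000 :: Integer])))
      print (length (show (combina 2000 1000)))
    ```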
    

    This can get rather inefficient for large n and medium-size k: the numerator product [n-k+1 .. n] is k! times the final answer, so the multiplications get huge. How might we keep them smaller? We can write a simple recurrence:

    combina :: Integer -> Integer -> Integer
    combina n k
      | k > n || k < 0 = 0
      | otherwise = combina' n k'
      where
        -- C(n,k) and C(n,n-k) are the same,
        -- so we choose the cheaper one.
        k' = min k (n-k)
    
    -- Assumes 0 <= k <= n
    combina' _n 0 = 1
    combina' n k
      = -- as above
      product [n-k+1 .. n] `div` product [1 .. k]
      = -- expanding a bit
      (n-k+1) * product [n-k+2 .. n] `div` (product [1 .. k-1] * k)
      = -- Rearranging, and temporarily using real division
      ((n-k+1)/k) * (product [n-(k-1)+1 .. n] / product [1 .. k-1])
      = -- Daniel's definition
      ((n-k+1)/k) * combina' n (k-1)
      = -- Rearranging again to go back to integer division
      ((n-k+1) * combina' n (k-1)) `quot` k
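
    The last step relies on the division being exact. Written out with factorials, the identity behind the recurrence is:

    ```latex
    \binom{n}{k} = \frac{n!}{k!\,(n-k)!}
                 = \frac{n-k+1}{k}\cdot\frac{n!}{(k-1)!\,(n-k+1)!}
                 = \frac{n-k+1}{k}\binom{n}{k-1}
    \quad\Longrightarrow\quad
    k\binom{n}{k} = (n-k+1)\binom{n}{k-1}
    ```

    So k always divides (n-k+1) * C(n, k-1), and multiplying before the `quot` never truncates. For example, with n = 5 and k = 2: (5-2+1) * C(5, 1) = 4 * 5 = 20 = 2 * 10.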
    

    Putting that all together,

    combina' _n 0 = 1
    combina' n k = ((n-k+1) * combina' n (k-1)) `quot` k
    

    Just one problem remains: this definition is not tail recursive. We can fix that by counting up instead of down:

    combina' n k0 = go 1 1
      where
        go acc k
          | k > k0 = n `seq` acc
          | otherwise = go ((n-k+1)*acc `quot` k) (k + 1)
    

    Don't worry about the n `seq`; it's of very little consequence.
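
    Putting the wrapper and the tail-recursive helper into one module gives something like the sketch below. The bang pattern on `acc` (via the BangPatterns extension) is an addition, not part of the answer's code: the tail call alone does not force `acc`, so an unoptimised build can still accumulate a chain of unevaluated multiplications. With `-O`, GHC's strictness analysis will often make the annotation unnecessary, but it makes the intent explicit.

    ```haskell
    {-# LANGUAGE BangPatterns #-}
    module Main where

    -- Wrapper from the answer: out-of-range k gives 0, and the cheaper of
    -- k and n-k is passed on, since C(n, k) == C(n, n-k).
    combina :: Integer -> Integer -> Integer
    combina n k
      | k > n || k < 0 = 0
      | otherwise      = combina' n (min k (n - k))

    -- Counting-up, tail-recursive helper. Assumes 0 <= k0 <= n.
    combina' :: Integer -> Integer -> Integer
    combina' n k0 = go 1 1
      where
        -- Invariant: on entry, acc == C(n, k - 1). The bang forces acc at
        -- every step (an addition to the answer's code).
        go :: Integer -> Integer -> Integer
        go !acc k
          | k > k0    = n `seq` acc
          | otherwise = go (((n - k + 1) * acc) `quot` k) (k + 1)

    main :: IO ()
    main = do
      print (combina 5 2)               -- 10
      print (map (combina 4) [0 .. 4])  -- [1,4,6,4,1]
      print (combina 3 5)               -- 0, because k > n
    ```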

    Note that this implementation uses O(min(k,n-k)) arithmetic operations. So if k and n-k are both very large, it will take a long time. I don't know if there's any efficient way to get exact results in that situation; I believe that binomial coefficients of that sort are usually estimated rather than calculated precisely.
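
    For the estimation route mentioned above, one standard option is Stirling's approximation to ln m!, which gives ln C(n, k) in a constant number of floating-point operations. This is an illustration rather than a recommendation: `logFactStirling`, `logChooseApprox`, and `logChooseSum` are hypothetical helper names, and the estimate's error shrinks as k and n-k grow (it is noticeably off for very small k).

    ```haskell
    module Main where

    -- Hypothetical helper: Stirling's approximation
    --   ln m! ~ m ln m - m + (1/2) ln (2 pi m),
    -- with ln 0! = 0 handled separately.
    logFactStirling :: Double -> Double
    logFactStirling m
      | m <= 0    = 0
      | otherwise = m * log m - m + 0.5 * log (2 * pi * m)

    -- Hypothetical helper: estimate ln C(n, k) in O(1) floating-point
    -- operations. Assumes 0 <= k <= n.
    logChooseApprox :: Integer -> Integer -> Double
    logChooseApprox n k =
      logFactStirling n' - logFactStirling k' - logFactStirling (n' - k')
      where
        n' = fromIntegral n
        k' = fromIntegral k

    -- Hypothetical reference: ln C(n, k) as a sum of min(k, n-k) log terms,
    -- i.e. the same O(min(k, n-k)) work as the exact loop, but in Doubles.
    logChooseSum :: Integer -> Integer -> Double
    logChooseSum n k =
      sum [ log (fromIntegral (n - j + i)) - log (fromIntegral i) | i <- [1 .. j] ]
      where
        j = min k (n - k)

    main :: IO ()
    main = do
      -- Both are cheap here, so they can be compared directly.
      print (logChooseApprox 1000 500, logChooseSum 1000 500)
      -- The estimate costs the same for much larger inputs.
      print (logChooseApprox (10 ^ 12) (5 * 10 ^ 11))
    ```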