Tail-recursively, maximum element in a binary tree in OCaml


I am practicing tail-recursion and I want to, given the type

type 'a tree = Leaf of 'a | Pair of 'a tree * 'a tree

and a function that finds the maximum element in a binary tree

let rec tree_max t = match t with 
    | Leaf v -> v 
    | Pair (l,r) -> max (tree_max l) (tree_max r)

make the above function tail-recursive.


I have tried

let rec tree_max t acc= match t with 
    | Leaf v -> acc
    | Pair (l,r) -> (max (tree_max l) (tree_max r))::ACC

and I have also tried

let rec tree_max t acc= match t with 
    | Leaf v -> acc
    | Pair (l,r) -> if (max (tree_max l) (tree_max l) = acc) then tree_max l acc else tree_max r acc

but they all yield syntax errors. Does anyone have an idea how to implement this?


Solution

  • TL;DR: this question is a little deeper than it might look, and since I don't want to do someone's homework, I decided to write a guide on turning recursive functions into tail-recursive ones. It turned out to be a little larger than I expected :)

    The ultimate guide on (tail) recursion

    What is tail recursion?

    It is not always easy to identify when a function is tail-recursive and when it is not. The key idea is that if you do some work after the recursive call, then your function is not tail-recursive. And to make it tail-recursive, we need to pass enough information to the recursive call so that the latter can compute the result without our intervention, i.e., the result of the function becomes the result of the recursive call.

    Here is a simple example, a non-tail-recursive length function that computes the length of a list,

    let rec length = function
      | [] -> 0
      | _ :: xs -> 1 + length xs
    

    We can see that in 1 + length xs the computer first evaluates length xs and then waits for its result just to add one to it. Obviously, we can pass the current length to the recursive call and let the base case return it, e.g.,

    let rec length acc = function
      | [] -> acc
      | _ :: xs -> length (acc+1) xs
    

    So, as you may see, the trick is to pass the information down the recursion using a parameter. An immediate caveat is that we now have an extra parameter, the acc (stands for accumulator). The convention is to hide this parameter from the end user of our interface in a nested function. I commonly call this function loop, e.g.,

    let length xs = 
      let rec loop acc = function
        | [] -> acc
        | _ :: xs -> loop (acc+1) xs in
      loop 0 xs
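
    As a quick sanity check, here is a minimal sketch that runs the nested-loop version on a long list. The list size and the name big are arbitrary choices made for illustration only.

```ocaml
(* Tail-recursive length: the accumulator is hidden in the nested loop. *)
let length xs =
  let rec loop acc = function
    | [] -> acc
    | _ :: xs -> loop (acc + 1) xs in
  loop 0 xs

(* An arbitrary large list, built only for this illustration. *)
let big = List.init 1_000_000 (fun i -> i)

let () = Printf.printf "%d\n" (length big)
```

    Since the recursive call is the last thing loop does, the stack depth stays constant no matter how long the list is.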
    

    In the length example we were able to improve the memory usage of our algorithm from O(N) to O(1). Indeed, in the non-tail-recursive version, the compiler created a chain of stack frames as long as the list, with each frame waiting for the length of the corresponding sublist. Essentially, the compiler built a singly-linked list of intermediate results and then reduced it with the + operator. That is a rather inefficient way of summing N numbers.

    But sometimes we can't reduce our recursion parameter to a single scalar value, so building the result on the stack can be beneficial. Consider the map function, which applies a function to each element of a list and builds a new list of results,

    let rec map f = function
      | [] -> []
      | x :: xs -> f x :: map f xs
    

    There is no way to reduce this function from O(N) to O(1), as we have to build and return a new list. But it consumes stack, which is a scarce resource (see footnote 1), unlike regular heap memory. So let's turn this function into a tail-recursive one. Again, we have to pass the necessary information down the recursion so that when we reach the bottom of the recursion we can build the answer,

    let map f xs =
      let rec loop acc = function
        | [] -> List.rev acc
        | x :: xs -> loop (f x :: acc) xs in
      loop [] xs
    

    As you can see, we again employed the trick with acc and a nested loop function to hide the extra parameter. Unlike the previous example, we use a list instead of an integer to accumulate the result. Since we were prepending to the list, the result ended up in reversed order, so we have to reverse it back at the bottom. We could append to the list instead, but that would be very inefficient: appending to a singly-linked list is an O(N) operation, so repeating it N times gives O(N^2) complexity.
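
    To make the trade-off concrete, here is a small sketch that puts the prepend-then-reverse version next to an appending one. The name map_append is hypothetical and exists only for this comparison; both functions return the same list, they differ only in cost.

```ocaml
(* Prepend each result in O(1) and reverse once at the bottom: O(N) total. *)
let map f xs =
  let rec loop acc = function
    | [] -> List.rev acc
    | x :: xs -> loop (f x :: acc) xs in
  loop [] xs

(* Hypothetical appending variant: no final reverse is needed, but every
   append copies the whole accumulator, so the total work is O(N^2). *)
let map_append f xs =
  let rec loop acc = function
    | [] -> acc
    | x :: xs -> loop (acc @ [f x]) xs in
  loop [] xs
```

    For short lists the difference is invisible, but it grows quickly with the length of the input, which is why the reversing version is the usual choice.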

    Recursion on trees

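    In the above examples we had more or less obvious choices for the accumulator, but what about tree-like structures? With trees, we have to recurse several times on each step (note that the code below names the branch constructor Tree, where the question's type calls it Pair), e.g.,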

    let max_elt t = 
      let rec loop current_max t = match t with
        | Leaf x -> max current_max x (* so far so good *)
        | Tree (l,r) -> <??> in
      loop <??> t
    

    As we can see, converting our function into accumulator-style recursion with the obvious choice of accumulator, the current maximum, didn't help. First of all, we now have to make a dubious choice of the initial max value. Next, and this is the main problem, we have two branches, l and r, and we can't recurse into both of them in the same recursive call... or can we?

    Yes, we can, and in fact, there are several solutions. Let's start with the accumulator-style recursion. Since we want to recurse on several subtrees (two in our case) and we want to do it in one call, we have to pass the other branch of the tree in the accumulator. So the accumulator itself becomes a list of trees that we haven't visited yet. The general name for this approach is the "worklist algorithm", and the key idea is that we do as much work as possible and push the rest onto the worklist, e.g.,

    let max_elt t =
      let rec loop work t =
        match t with
        | Tree (l,r) -> loop (l::work) r
        | Leaf x -> match work with
          | [] -> x
          | [Leaf x'] -> max x x'
          | Tree _ as t' :: ts -> loop (t::ts) t'
          | Leaf x' :: t :: ts -> loop (Leaf (max x x') :: ts) t in
      loop [] t
    

    Whoops, that looks complicated, much more complicated than the original stack-consuming version. Was it worthwhile, i.e., did we significantly improve our performance? For a balanced tree, no, we are on par: the stack-based version was consuming O(depth(t)) stack slots, where depth(t) is the depth of the tree. Since the size N of a balanced binary tree (the number of nodes) is about 2^depth(t), we were essentially consuming O(log(N)) stack, which is good. Our tail-recursive version consumes the same amount of memory, only on the heap. And for a balanced tree we should not be afraid of consuming stack memory, since we will run out of heap memory long before that (again, to store a balanced tree of depth D we need about 2^D elements, so even if the stack were limited to 64 slots, we would be able to process trees as large as 2^64 elements, which is much more than any computer can hold). This is why for balanced trees we do not need to worry about tail recursion and can safely use regular recursion, which is more readable and efficient.

    But what if the tree is not balanced? For example, when we have subtrees only on the left side and leaves on the right side. In that extreme case, the depth of the tree is equal to the number of elements. Unfortunately, unless we specifically balance our trees, which is another interesting topic, we will end up with such trees very often, e.g., try to write an of_list function that generates a tree from a list; chances are high that you will end up with an unbalanced tree.
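
    Here is one way such an of_list might look; this is a hypothetical sketch, not a recommended implementation. Every element is attached as a right-hand leaf, so the tree degenerates into a left-leaning chain. The helper left_spine_length is also made up for this example and simply counts the branch nodes along the left edge.

```ocaml
type 'a tree = Leaf of 'a | Tree of 'a tree * 'a tree

(* Hypothetical of_list: the tree built so far always becomes the left
   child, so the depth grows linearly with the number of elements. *)
let of_list = function
  | [] -> invalid_arg "of_list: empty list"
  | x :: xs -> List.fold_left (fun acc y -> Tree (acc, Leaf y)) (Leaf x) xs

(* Count the branch nodes on the left edge; tail-recursive, so it is
   safe to run on very deep trees. *)
let rec left_spine_length acc = function
  | Leaf _ -> acc
  | Tree (l, _) -> left_spine_length (acc + 1) l

let () =
  let t = of_list (List.init 1000 (fun i -> i)) in
  Printf.printf "%d\n" (left_spine_length 0 t)
```

    A tree of n elements built this way has n - 1 nested branch nodes, so a function that recurses on the left child without a tail call needs stack proportional to n.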

    For unbalanced trees, we can easily get a stack overflow. But the function above is hard to understand, and it is hard to prove that it terminates.

    Maybe there is a way to make it tail-recursive and understandable? The answer is yes, so read on.

    Recursion on trees (no brains involved)

    Okay, the next trick is a little bit hard to grasp, because it involves continuations. If you stop your brain from trying to understand the concept, it will be a no-brainer exercise, a mere technical substitution. But we are not looking for the easy path so let's make our brains do some work.

    The key idea is still the same: we need to call the recursive function once and pass it some value that is necessary for the recursion to finish the work. In our previous example, we reified "the work to finish" as a list of nodes that are not yet processed. But what if we represent it as a function that receives the intermediate result and continues the work, i.e., as a continuation? The convention is to call the continuation k, but we will call it return,

    let max_elt t =
      let rec loop t return =
        match t with
        | Leaf x -> return x
        | Tree (l,r) ->
          loop l @@ fun x ->
          loop r @@ fun y ->
          return (max x y) in
      loop t (fun x -> x)
    

    Okay, first of all, what is @@, and aren't we calling loop twice? @@ is the application operator, which is defined as let (@@) f x = f x, so if we rewrite our loop function without it, we can see that we indeed call loop only once per invocation,

    let rec loop t return =
        match t with
        | Leaf x -> return x
        | Tree (l,r) ->
          loop l (fun x ->
              loop r (fun y ->
                  return (max x y)))
    

    So, slowly: in the Tree (l,r) case we call loop l <cont1>, passing it a continuation (a function) that receives the maximum value of the subtree l and then calls loop r <cont2>. In turn, cont2 receives the maximum value of r, combines the two to find the overall maximum, and then uses the continuation return to send it back (upwards).

    Still not completely sure that it is tail recursion? Let's rewrite it even more verbosely,

      let rec loop t return =
        match t with
        | Leaf x -> return x
        | Tree (l,r) ->
          let cont1 x =
            let cont2 y = return (max x y) in
            loop r cont2 in
          loop l cont1
    

    As you can see, all calls to loop are in tail position. But how does it work?

    The big idea here is that each continuation captures some information, i.e., in each we have some free variable that is captured from the outer scope, e.g., cont1 captures the right tree r, and cont2 captures the upper-level continuation return, so that in the end we have a linked list of continuations. So no free cheese, we still use O(N) memory slots, but since continuations are stored in the heap memory we are no longer wasting our valuable stack.
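
    To see the continuations at work, it may help to trace a tiny input by hand. The sketch below repeats the verbose version and writes out the evaluation steps for a two-leaf tree in a comment; id in the trace stands for the initial continuation fun x -> x.

```ocaml
type 'a tree = Leaf of 'a | Tree of 'a tree * 'a tree

let max_elt t =
  let rec loop t return =
    match t with
    | Leaf x -> return x
    | Tree (l, r) ->
      (* cont1 closes over r and return; cont2 closes over x and return. *)
      let cont1 x =
        let cont2 y = return (max x y) in
        loop r cont2 in
      loop l cont1 in
  loop t (fun x -> x)

(* Evaluation of max_elt (Tree (Leaf 1, Leaf 2)), step by step:
     loop (Tree (Leaf 1, Leaf 2)) id
   = loop (Leaf 1) cont1      (cont1 remembers Leaf 2 and id)
   = cont1 1
   = loop (Leaf 2) cont2      (cont2 remembers x = 1 and id)
   = cont2 2
   = id (max 1 2)
   = 2 *)
let () = Printf.printf "%d\n" (max_elt (Tree (Leaf 1, Leaf 2)))
```

    Every step in the trace replaces the current call with the next one, and nothing is left waiting on the stack; the pending work lives entirely in the closures cont1 and cont2.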

    Okay, now the final step: how can we apply this technique without overstressing our brains, i.e., purely syntactically? For that we will use a relatively new feature of OCaml called binding operators (available since OCaml 4.08), which allows us to define our own letXX operators, e.g.,

    let (let@) = (@@)
    

    so that f (fun x -> <body>) can be expressed as let@ x = f in <body>. Using this operator, we can write our tail-recursive version in nearly the same way as the non-tail-recursive one,

    let max_elt t =
      let rec loop t return =
        match t with
        | Leaf x -> return x
        | Tree (l,r) ->
          let@ x = loop l in
          let@ y = loop r in
          return (max x y) in
      loop t (fun x -> x)
    

    Compare it with the non-tail-recursive version,

    let rec max_elt t =
      match t with
      | Leaf x -> x
      | Tree (l,r) ->
        let x = max_elt l in
        let y = max_elt r in
        max x y
    

    So we can build a simple syntactic rule,

    1. append an extra function parameter (the continuation, return) to the parameter list
    2. all recursive calls shall be bound with let@
    3. the returned value shall be passed via return, unless we want to escape from the recursion earlier.
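
    As an illustration of how mechanical the rule is, here is a sketch that applies it to a different function. The functions height_direct and height are hypothetical examples that are not part of the question, and let@ is written out as a plain function that behaves the same as (@@).

```ocaml
type 'a tree = Leaf of 'a | Tree of 'a tree * 'a tree

(* Binding operator that passes the rest of the body as a continuation;
   requires OCaml 4.08 or later. *)
let ( let@ ) f k = f k

(* Direct version, for comparison: not tail-recursive. *)
let rec height_direct t =
  match t with
  | Leaf _ -> 0
  | Tree (l, r) ->
    let hl = height_direct l in
    let hr = height_direct r in
    1 + max hl hr

(* Rule 1: extra parameter return. Rule 2: recursive calls are bound
   with let@. Rule 3: the result is passed via return. *)
let height t =
  let rec loop t return =
    match t with
    | Leaf _ -> return 0
    | Tree (l, r) ->
      let@ hl = loop l in
      let@ hr = loop r in
      return (1 + max hl hr) in
  loop t (fun h -> h)
```

    The two bodies differ only in let versus let@ and in the explicit return, which is the whole point of the rule.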

    Early escape or short-circuiting

    What if we don't use return to pass the value upwards in the recursion? In that case, our result immediately becomes the result of the whole top-level call, so we can short-circuit our search as soon as an element is found and skip the remaining elements. For example, here is how we can test our tree for element membership,

    let has t needle =
      let rec loop t return = match t with
        | Leaf x -> x = needle || return ()
        | Tree (l,r) ->
          let@ () = loop l in
          loop r return in
      loop t (fun () -> false)
    

    You can see that in let@ () = loop l we don't even bother to return true or false from the subtree search. The mere fact that our continuation was called is evidence that the element is not present in the left subtree, so we need to check the right one.
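
    One way to observe the short-circuit is to count how many leaves are inspected. The sketch below is an instrumented variant of has; the name has_counting, the counter, and the sample tree are all assumptions added for illustration.

```ocaml
type 'a tree = Leaf of 'a | Tree of 'a tree * 'a tree

let ( let@ ) f k = f k

(* Hypothetical instrumented variant of has: it also reports how many
   leaves were inspected before the search finished. *)
let has_counting t needle =
  let visited = ref 0 in
  let rec loop t return =
    match t with
    | Leaf x ->
      incr visited;
      (* On a hit we answer true without calling return, so none of the
         pending continuations (the unvisited subtrees) ever runs. *)
      x = needle || return ()
    | Tree (l, r) ->
      let@ () = loop l in
      loop r return in
  let found = loop t (fun () -> false) in
  (found, !visited)

let sample = Tree (Tree (Leaf 1, Leaf 2), Tree (Leaf 3, Leaf 4))
```

    With this sample tree, searching for 2 inspects only the first two leaves, while searching for a missing element has to inspect all four.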

    Continuations are a very powerful feature of functional languages, and you can implement wonderful things with them, like variadic functions, non-deterministic computations, backtracking, and so on. But this is a different topic, which I hope we will eventually explore.


    1) Whether it is scarce or not depends on the architecture, the operating system, and its configuration. On modern Linux, you can make the stack size unlimited, and not worry about stack overflows at all, by setting ulimit -s unlimited. On other operating systems, there is a hard limit on the maximum stack size.