Tags: python, loops, recursion, return, binary-tree

How to pass on values when converting a recursive function to an iterative one with a stack and a while loop


EDIT: My initial recursive code was incomplete, as I did not multiply the sum of the two recursive branches by f(l, r).

Given a function f(l, r) that computes a value for the node (l, r) of a binary tree of height max_height, I want to compute a total value: for each node, add the values passed up by its left and right children, multiply that sum by the node's own value f(l, r), and pass the result up the tree.

I have a working recursive implementation of this, but now I want to eliminate the recursion using a while loop and stack structures. My problem is that I don't know how to "pass on" values during the while loop. That is, I don't know how to replicate the step where I multiply the current value f(l, r) by the sum of the two recursive branches.

I have included two code snippets: the first is the current recursive implementation and the second is my attempt at an iterative approach. The latter still needs more work, and I have placed TODO comments to indicate some of my questions.

def recursive_travel(l, r, cur_height, max_height):
    if cur_height == max_height - 1:
        return f(l, r) * (f(l + 1, r) + f(l, r + 1))
    return f(l, r) * (recursive_travel(l + 1, r, cur_height + 1, max_height)
                      + recursive_travel(l, r + 1, cur_height + 1, max_height))

where the initial call will be recursive_travel(0, 0, 0, max_height).
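To make the "passing on" concrete: with max_height = 2, the calls at cur_height == 1 hit the base case, so the initial call recursive_travel(0, 0, 0, 2) expands to the expression below. Each parent multiplies its own f value by the sum its two children pass up. Note that f(1, 1) is reached along two different paths.

```latex
\mathrm{recursive\_travel}(0, 0, 0, 2) = f(0,0)\,\Bigl[\, f(1,0)\,\bigl(f(2,0) + f(1,1)\bigr) + f(0,1)\,\bigl(f(1,1) + f(0,2)\bigr) \Bigr]
```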

My attempt at removing the recursion:

def iterative_travel(max_height):
    call_stack = [(0, 0, 0)]  # cur_height, l, r in that order
    handled_stack = []  # TODO: Maybe I need to have something like this, or maybe I need a double array to store computed values?

    # Precompute the values of f into an (n + 1) x (n + 1) table for fast access, indexed as pre_f[l][r]
    pre_f = [[f(l, r) for r in range(max_height + 1)] for l in range(max_height + 1)]

    while call_stack:
        cur_height, l, r = call_stack.pop()
        if cur_height == max_height - 1:
            # TODO: Not sure how to pass on the computed values

            # TODO: Where should I put this value? In some table? In some stack?
            value = pre_f[l][r] * (pre_f[l + 1][r] + pre_f[l][r + 1])

            # TODO: Should I mark somewhere that the node (l, r) has been handled?
        elif handled_stack:
            # TODO: Not sure how to handle the computed values
            pass
        else:
            # TODO: Do I do something to the current l and r here?
            call_stack.append((cur_height + 1, l + 1, r))
            call_stack.append((cur_height + 1, l, r + 1))
    return 0  # TODO: Return the correct value

Solution

  • The job of your recursive function is to calculate a value for each leaf in the virtual, perfect binary tree, and then bubble those values up so that each parent node gets the sum of its two children's values, multiplied by its own value.
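The bubbling up can also be done literally with a stack and a while loop, which is what the question's attempt was aiming for. Below is a minimal sketch. It assumes the question's original base case (cur_height == max_height - 1) and the example f used later in this answer. Every node is visited twice: once to schedule its children, and once to combine their values. A second stack (results) carries each finished subtree value back up to its parent. This only removes the recursion. It still walks the full tree, so it stays exponential unless you apply the observations that follow.

```python
def f(l, r):
    # Example node function (the one used later in this answer).
    return 2 + l * 3 + r * 5


def recursive_travel(l, r, cur_height, max_height):
    # The question's recursive version, kept for comparison.
    if cur_height == max_height - 1:
        return f(l, r) * (f(l + 1, r) + f(l, r + 1))
    return f(l, r) * (recursive_travel(l + 1, r, cur_height + 1, max_height)
                      + recursive_travel(l, r + 1, cur_height + 1, max_height))


def stack_travel(max_height):
    if max_height < 1:
        raise ValueError("max_height must be at least 1, as in the recursive version")
    # Each frame is (cur_height, l, r, children_done).
    call_stack = [(0, 0, 0, False)]
    # Finished subtree values; a parent pops its two children's values from here.
    results = []
    while call_stack:
        cur_height, l, r, children_done = call_stack.pop()
        if cur_height == max_height - 1:
            # Base case, identical to the recursive version.
            results.append(f(l, r) * (f(l + 1, r) + f(l, r + 1)))
        elif not children_done:
            # First visit: push a "combine" frame for later, then both children.
            call_stack.append((cur_height, l, r, True))
            call_stack.append((cur_height + 1, l, r + 1, False))
            call_stack.append((cur_height + 1, l + 1, r, False))
        else:
            # Second visit: both children have pushed their values by now.
            right = results.pop()
            left = results.pop()
            results.append(f(l, r) * (left + right))
    # Exactly one value is left: the root's.
    return results.pop()


print(stack_travel(2) == recursive_travel(0, 0, 0, 2))  # True
```

The combine frame must be pushed before the two children. That way it is popped only after both subtrees have finished and each has left exactly one value on results.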

    The base case of the recursion still looks one level further into the virtual tree, since it calls f with an incremented argument. I would therefore rewrite the recursive function with a base case that is one level deeper:

    def recursive_travel(l, r, cur_height, max_height):
        if cur_height == max_height + 1:
            return f(l, r)
        return f(l, r) * (recursive_travel(l + 1, r, cur_height + 1, max_height)
                        + recursive_travel(l, r + 1, cur_height + 1, max_height))
    

    This better shows how each node contributes its "own" value (determined by f). It also means the virtual binary tree is actually one level higher than max_height.

    Another observation is that l and r represent the number of left and right turns on the path from the root to the node (so cur_height is always l + r). This means that two subtrees with the same l and r will get exactly the same value, and it would be a waste to calculate such a subtree more than once.
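Because cur_height is always l + r, a subtree's value depends only on (l, r). Below is a minimal sketch of that idea. It assumes the rewritten base case above (leaves at depth max_height + 1) and the example f used later in this answer. make_travel and value are illustrative names, not from the question. Caching on (l, r) alone means each distinct node is computed exactly once.

```python
from functools import cache


def f(l, r):
    # Example node function (the one used later in this answer).
    return 2 + l * 3 + r * 5


def make_travel(max_height):
    # cur_height is always l + r, so (l, r) alone identifies a subtree.
    @cache
    def value(l, r):
        if l + r == max_height + 1:
            # Leaf level, matching the rewritten base case above.
            return f(l, r)
        return f(l, r) * (value(l + 1, r) + value(l, r + 1))
    return value


travel = make_travel(20)
travel(0, 0)
# Depth d has d + 1 distinct (l, r) pairs, so 22 * 23 // 2 == 253 entries.
print(travel.cache_info().currsize)  # 253
```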

    Your tree is thus really more of a network, in which nodes with the same (l, r) are merged. It looks like this:

                    /\
                   /\/\
                  /\/\/\
                 /\/\/\/\
                /\/\/\/\/\ 
               /\/\/\/\/\/\
              ... ... ... ...
    

    This reduces the number of nodes you actually have to evaluate, making the number of calculations O(n²) instead of O(2ⁿ), where n is max_height.
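The sketch below puts numbers on the difference, again assuming leaves at depth max_height + 1. It walks the levels of the tree and, for each distinct (l, r), tracks how many root paths reach it (these counts are the binomial coefficients of Pascal's triangle, matching the picture above). The plain recursion visits one node per path, while the shared version evaluates each (l, r) once. node_counts is a hypothetical helper for illustration.

```python
from collections import Counter


def node_counts(max_height):
    """Return (tree_nodes, distinct_nodes) over depths 0 .. max_height + 1."""
    # Maps (l, r) -> number of root-to-node paths that reach it.
    level = Counter({(0, 0): 1})
    tree_nodes = distinct_nodes = 0
    for _ in range(max_height + 2):
        tree_nodes += sum(level.values())  # calls made by plain recursion
        distinct_nodes += len(level)       # nodes that actually differ
        children = Counter()
        for (l, r), paths in level.items():
            children[(l + 1, r)] += paths
            children[(l, r + 1)] += paths
        level = children
    return tree_nodes, distinct_nodes


for n in (10, 20, 30):
    print(n, node_counts(n))
# 10 (4095, 78)
# 20 (4194303, 253)
# 30 (4294967295, 528)
```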

    We could turn this into an iterative solution that works bottom-up: first calculate the result for all the unique l-r combinations at the bottom level, then for all their parents, and so on, until the root's value is determined, which is also the end result.

    Putting all those observations together, we get this code:

    def iterative_travel(max_height):
        level = [
            f(left, max_height + 1 - left)
            for left in range(max_height + 2)
        ]
        for height in range(max_height, -1, -1):
            for left in range(height + 1):
                level[left] = f(left, height - left) * (level[left] + level[left + 1])
        return level[0]
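The loop above starts from depth max_height + 1, matching the rewritten base case. For a given max_height, it therefore evaluates f one level deeper than the question's original function, whose base case at cur_height == max_height - 1 only reaches depth max_height. If the question's exact numbers are needed, here is a sketch of the same loop shifted up one level (iterative_travel_original_depth is a hypothetical name; f is the example used later in this answer):

```python
def f(l, r):
    # Example node function (the one used later in this answer).
    return 2 + l * 3 + r * 5


def iterative_travel_original_depth(max_height):
    # Bottom level is depth max_height: the pairs (left, max_height - left).
    level = [f(left, max_height - left) for left in range(max_height + 1)]
    for height in range(max_height - 1, -1, -1):
        for left in range(height + 1):
            # level[left] holds child (left, r + 1), level[left + 1] holds child (left + 1, r);
            # updating left to right never overwrites a value that is still needed.
            level[left] = f(left, height - left) * (level[left] + level[left + 1])
    return level[0]


print(iterative_travel_original_depth(1))  # 24, i.e. f(0, 0) * (f(1, 0) + f(0, 1))
print(iterative_travel_original_depth(2))  # 488
```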
    

    Efficiency

    There shouldn't be much difference in running time with your recursive approach, provided that you apply memoization. This can, for instance, be done with functools.cache (Python 3.9+; on older versions, functools.lru_cache(maxsize=None) does the same).

    For instance, I made up a simple f and ran both versions with a max height of 100:

    from functools import cache
    
    def f(l, r):
        return 2 + l * 3 + r * 5  # Some example calculation
    
    @cache
    def recursive_travel(l, r, cur_height, max_height):
        if cur_height == max_height + 1:
            return f(l, r)
        return f(l, r) * (recursive_travel(l + 1, r, cur_height + 1, max_height)
                        + recursive_travel(l, r + 1, cur_height + 1, max_height))
    
    def iterative_travel(max_height):
        level = [
            f(left, max_height + 1 - left)
            for left in range(max_height + 2)
        ]
        for height in range(max_height, -1, -1):
            for left in range(height + 1):
                level[left] = f(left, height - left) * (level[left] + level[left + 1])
        return level[0]
    
    n = 100
    result = recursive_travel(0, 0, 0, n)
    print(result)
    result2 = iterative_travel(n)
    print(result2)
    

    This script prints the result in a fraction of a second for both approaches. For larger heights, like 300, the iterative version turns out to be faster, but both still finish in under a second.
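There is one more practical reason to remove the recursion. CPython limits recursion depth (sys.getrecursionlimit() returns 1000 by default), and the memoized recursive version nests about max_height calls deep. The sketch below assumes a trivial f that returns 1 everywhere, so the result is simply the number of root-to-leaf paths, 2 ** (max_height + 1). The recursive call is wrapped in try/except rather than assumed to fail, since the limit can be changed with sys.setrecursionlimit.

```python
import sys
from functools import cache


def f(l, r):
    # Trivial example f: every node contributes 1, so the total counts paths.
    return 1


@cache
def recursive_travel(l, r, cur_height, max_height):
    if cur_height == max_height + 1:
        return f(l, r)
    return f(l, r) * (recursive_travel(l + 1, r, cur_height + 1, max_height)
                      + recursive_travel(l, r + 1, cur_height + 1, max_height))


def iterative_travel(max_height):
    level = [f(left, max_height + 1 - left) for left in range(max_height + 2)]
    for height in range(max_height, -1, -1):
        for left in range(height + 1):
            level[left] = f(left, height - left) * (level[left] + level[left + 1])
    return level[0]


n = 1200  # deeper than the default recursion limit of 1000
print("recursion limit:", sys.getrecursionlimit())
try:
    recursive_travel(0, 0, 0, n)
    print("recursive: finished")
except RecursionError:
    print("recursive: RecursionError")
# Needs no call stack, only a list of max_height + 2 values.
print("iterative:", iterative_travel(n) == 2 ** (n + 1))
```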