Tags: python, algorithm, recursion, dynamic-programming, divide-and-conquer

Transforming a divide and conquer recursive algorithm into an iterative version


I would like to transform a recursive algorithm on an array into an iterative function. It is not tail recursive: it makes two recursive calls and then applies some operation to their results. The algorithm is divide-and-conquer: at each step the array is split into two subarrays, and some function f is applied to the outcomes of the two halves. In practice f is complicated, so the iterative algorithm should still call the function f; for the minimal working example I have used simple addition.

Below is a minimal working example of the recursive program in Python.

import numpy as np

def f(left,right):
    #In practice some complicated function of left and right
    value=left+right
    return value

def recursive(w,i,j):
    if i==j:
        #termination condition when the subarray has size 1
        return w[i]
    else:
        k=(j-i)//2+i
        #split the array into two subarrays between indices i,k and k+1,j
        left=recursive(w,i,k)
        right=recursive(w,k+1,j)

        return f(left,right)

a=np.random.rand(10)
print(recursive(a,0,a.shape[0]-1))

Now if I want to write this iteratively I realize that I need a stack to store intermediate results, and that at each step I need to apply f to the two elements on the top of the stack. I am just not sure how to construct the order in which I put elements in the stack without recursion. Here is an attempt at a solution which is certainly not optimal since it seems there should be a way to remove the first loop and use only one stack:

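def iterative(w):
    stack=[(0,w.shape[0]-1)]   #intervals that still have to be split
    stack2=[]                  #all intervals in pre-order (node, left, right)
    stack3=[]                  #intermediate results
    #looping until the stack is empty also handles an array of size 1
    while len(stack)>0:
        (i,j)=stack.pop()
        stack2.append((i,j))
        if i!=j:
            k=(j-i)//2+i
            stack.append((k+1,j))
            stack.append((i,k))
    while len(stack2)>0:
        #stack2 is read in reverse pre-order: right subtree, left subtree, node
        (i,j)=stack2.pop()
        if i==j:
            stack3.append(w[i])
        else:
            #the left result was pushed last, so it is on top of stack3
            left=stack3.pop()
            right=stack3.pop()
            stack3.append(f(left,right))
    return stack3.pop()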

Edit: The real problem I am interested in takes as input an array of tensors of different sizes, and the operation f solves a linear program involving these tensors and outputs a new tensor. I cannot simply iterate over the initial array, since the size of the output of f grows exponentially in that case. This is why I use this divide-and-conquer approach, which reduces this size. The recursive program works fine but slows down dramatically for large inputs, possibly due to the frames that Python opens and keeps track of.
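
One rough way to check whether frame overhead is a plausible cause is to count the recursive calls and the maximum depth. The sketch below is an illustration only: `recursive_stats` is a hypothetical instrumented copy of `recursive`, and `f` is the addition stand-in from the example, not the real linear-program step. For an input of n elements this split makes 2n-1 calls and reaches a depth of about log2(n)+1.

```python
def f(left, right):
    # Stand-in for the real, expensive combination step (assumption: addition).
    return left + right

def recursive_stats(w, i, j, depth=1, stats=None):
    """Same split as `recursive`, but also records call count and maximum depth."""
    if stats is None:
        stats = {"calls": 0, "max_depth": 0}
    stats["calls"] += 1
    stats["max_depth"] = max(stats["max_depth"], depth)
    if i == j:
        # A subarray of size 1 ends the recursion.
        return w[i], stats
    k = (j - i) // 2 + i
    left, _ = recursive_stats(w, i, k, depth + 1, stats)
    right, _ = recursive_stats(w, k + 1, j, depth + 1, stats)
    return f(left, right), stats

value, stats = recursive_stats(list(range(10)), 0, 9)
print(value, stats)
# 45 {'calls': 19, 'max_depth': 5}
```

If the depth stays this small while the run time still grows sharply, the time is more likely spent inside f than in Python's call frames.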


Solution

  • Below I transformed the program to use a continuation (then) and a trampoline (run/recur). It evolves a linear iterative process and it will not overflow the stack. If you're not running into a stack overflow issue, this won't do much to help your specific problem, but it can teach you how to flatten branching computations.

    Converting a normal function to continuation-passing style can be a mechanical process. If you squint your eyes a little bit, you'll see that the program has most of the same elements as yours. Inline comments show the corresponding line of your code side-by-side:

    import numpy as np
    
    def identity (x):
      return x
    
    def recur (*values):
      return (recur, values)
    
    def run (f):
      acc = f ()
      while type (acc) is tuple and acc [0] is recur:
        acc = f (*acc [1])
      return acc
    
    def myfunc (a):
      # def recursive(w,i,j)
      def loop (w = a, i = 0, j = len(a)-1, then = identity):
        if i == j:                # same
          return then (w[i])      # wrap in `then`
        else:                     # same
          k = (j - i) // 2 + i    # same
          return recur (          # left=recursive(w,i,k)
              w
            , i
            , k
            , lambda left:
              recur               # right=recursive(w,k+1,j)
                ( w
                , k + 1
                , j
                , lambda right:
                    then          # wrap in `then`
                      (f (left, right)) # same
                )
            )
      return run (loop)
    
    def f (a, b):
        return a + b              # same
    
    a = np.random.rand(10)        # same
    print(a, myfunc(a))           # recursive(a, 0, a.shape[0]-1)
    
    # [0.5732646  0.88264091 0.37519826 0.3530782  0.83281033 0.50063843 0.59621896 0.50165139 0.05551734 0.53719382]
    
    # 5.208212213881435
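
For comparison, the same left-then-right evaluation can also be sketched with a single explicit stack, which is what the question asks about. This is a minimal sketch under stated assumptions, not the method used above: `iterative_one_stack`, `combine` and `_PENDING` are hypothetical names introduced here, and `combine` plays the role of `f`. Each stack entry is a frame `[i, j, left]` that stands in for one call of `recursive`, so the stack never holds more frames than the depth of the split.

```python
_PENDING = object()  # marks a frame whose left half has not been evaluated yet

def iterative_one_stack(w, combine):
    """Evaluate the same split as `recursive`, using one explicit stack of frames."""
    if len(w) == 0:
        raise ValueError("w must contain at least one element")
    stack = [[0, len(w) - 1, _PENDING]]
    while True:
        i, j, _ = stack[-1]
        if i != j:
            # Descend into the left half first, as the recursive version does.
            stack.append([i, (j - i) // 2 + i, _PENDING])
            continue
        # Leaf reached: its value is the result of this frame.
        result = w[i]
        stack.pop()
        # Hand the result to the enclosing frames.
        while stack:
            parent = stack[-1]
            if parent[2] is _PENDING:
                # First result for this parent: store it, then start the right half.
                parent[2] = result
                pi, pj = parent[0], parent[1]
                stack.append([(pj - pi) // 2 + pi + 1, pj, _PENDING])
                break
            # Second result: both halves are known, so combine and keep unwinding.
            result = combine(parent[2], result)
            stack.pop()
        else:
            # The stack emptied, so `result` belongs to the whole array.
            return result

# A non-commutative `combine` makes the left/right order visible.
words = ["a", "b", "c", "d", "e"]
print(iterative_one_stack(words, lambda left, right: f"({left}{right})"))
# (((ab)c)(de))
```

Using a `combine` that is not commutative is a convenient check, because swapping the two operands gives a visibly different result, which addition would hide.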