sql recursion common-table-expression

Recursive function in procedural-code vs SQL


I am wondering whether my SQL translation of the following recursive code example (in Python here, but the language doesn't matter) is approximately correct:

def countdown(num):
    if num > 0:
        return countdown(num-1)
    else:
        return num

Now, this doesn't have any 'output', so to make it similar to SQL rows, I'm collecting each 'result' in an array:

res = []
def countdown(num):
    res.append(num)
    if num > 0:
        return countdown(num-1)
    else:
        return res

res = countdown(4)
# [4, 3, 2, 1, 0]

And here is my attempt in SQL:

WITH RECURSIVE countdown AS 
(
    SELECT 4 AS num
    UNION ALL
    SELECT num-1 FROM countdown WHERE num > 0
)
SELECT num 
FROM countdown;

 num 
-----
   4
   3
   2
   1
   0

Is this an accurate translation? A few things that seem 'weird' to me with the SQL implementation:

  • The 'parameter' (num = 4) here needs to be hardcoded inside the CTE. Is there any way around this so it reads more like a 'normal recursive function'? I'm almost expecting something like SELECT num FROM countdown(4).
  • The base condition starts things off, whereas the base condition in a procedural function 'ends things'. Is that a correct understanding?
  • The WHERE condition essentially encodes the base condition (num > 0).

Is that a correct understanding here?
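
On the first point, a minimal sketch of one possible workaround, assuming the query is issued from host code: the seed value can be passed as a bound parameter instead of being typed into the CTE. The example uses Python's built-in sqlite3 module with an in-memory database; the names COUNTDOWN_SQL and countdown_sql are hypothetical, and an ORDER BY is added to the outer query because row order is otherwise not guaranteed.

```python
import sqlite3

# Hypothetical query text: the seed row comes from a bound parameter (?)
# rather than a literal written into the CTE.
COUNTDOWN_SQL = """
WITH RECURSIVE countdown(num) AS (
    SELECT ?
    UNION ALL
    SELECT num - 1 FROM countdown WHERE num > 0
)
SELECT num FROM countdown ORDER BY num DESC
"""

def countdown_sql(start):
    """Run the recursive CTE with `start` as the seed value."""
    conn = sqlite3.connect(":memory:")
    try:
        rows = conn.execute(COUNTDOWN_SQL, (start,)).fetchall()
    finally:
        conn.close()
    return [row[0] for row in rows]

print(countdown_sql(4))  # [4, 3, 2, 1, 0]
```

This only moves the value out of the query text; the CTE itself still has no parameter list of its own.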


Solution

  • I don't know whether this helps, but here is a (reasonable) procedural interpretation of what the recursive SQL might do:

    def extend_cur_set_to_new_set(cur_set):
        cur_values = list(cur_set)
        new_values = [ (i - 1) for i in cur_values if i > 0 ]
        return set(new_values + cur_values)
    
    def countdown_find_fixpoint(cur_set):
        print(f"countdown_find_fixpoint({cur_set})")
        new_set = extend_cur_set_to_new_set(cur_set)
        print(f"{cur_set} has been extended to {new_set}")
        if cur_set == new_set:
            print("Fixpoint reached")
            return cur_set
        else:
            return countdown_find_fixpoint(cur_set | new_set)
    
    print(countdown_find_fixpoint({4}))
    

    When the above runs:

    countdown_find_fixpoint({4})
    {4} has been extended to {3, 4}
    countdown_find_fixpoint({3, 4})
    {3, 4} has been extended to {2, 3, 4}
    countdown_find_fixpoint({2, 3, 4})
    {2, 3, 4} has been extended to {1, 2, 3, 4}
    countdown_find_fixpoint({1, 2, 3, 4})
    {1, 2, 3, 4} has been extended to {0, 1, 2, 3, 4}
    countdown_find_fixpoint({0, 1, 2, 3, 4})
    {0, 1, 2, 3, 4} has been extended to {0, 1, 2, 3, 4}
    Fixpoint reached
    {0, 1, 2, 3, 4}
    

    The 4 is not a parameter; it is the initial set {4}, which is grown until the fixpoint (i.e. the set that, when grown again, yields itself) has been reached. At that point we are done.
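
A complementary sketch, assuming the working-table evaluation that database documentation (PostgreSQL's, for instance) describes for UNION ALL recursion: instead of re-extending the whole set, only the rows produced by the previous step are fed to the recursive term, and evaluation stops when a step produces no rows. The names run_recursive_cte and countdown_step are hypothetical and exist only for illustration.

```python
def run_recursive_cte(seed_rows, recursive_term):
    """Evaluate a UNION ALL style recursion with a working table."""
    result = list(seed_rows)    # rows from the non-recursive term
    working = list(seed_rows)   # rows visible to the recursive term in the next step
    while working:              # stop once a step yields no new rows
        new_rows = recursive_term(working)
        result.extend(new_rows)
        working = new_rows      # only the newest rows feed the next step
    return result

def countdown_step(working):
    # Mirrors: SELECT num - 1 FROM countdown WHERE num > 0
    return [num - 1 for num in working if num > 0]

print(run_recursive_cte([4], countdown_step))  # [4, 3, 2, 1, 0]
```

Seen this way, the seed row starts the process and the WHERE clause ends it by eventually filtering out every row.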

    This is quite different from growing an initially empty list (or set) until a stop condition has been reached, which is the point when all the desired elements are in the list (or set). This can be written using either a loop or a recursive descent (tail-recursive or not), depending on taste. In Python one would rather use a loop, as the list holding the result is mutable and one can insert() into it and append() to it freely.
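
For illustration, a minimal sketch of the loop variant mentioned above, growing an initially empty list until the stop condition is reached (the name countdown_loop is hypothetical):

```python
def countdown_loop(num):
    """Collect num, num-1, ..., 0 by appending to a mutable list."""
    result = []          # starts empty and is grown in place
    cur = num
    while cur >= 0:      # stop condition: every desired element has been added
        result.append(cur)
        cur -= 1
    return result

print(countdown_loop(4))  # [4, 3, 2, 1, 0]
```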

    P.S.

    In languages that have immutable data structures (like Prolog or Lisp/Scheme), one would write recursive code like the following. It is properly tail-recursive once the print() instructions are removed, so a language that optimizes tail calls would not grow the stack, just as with a loop; Python does not perform that optimization in any case:

    def countdown_tailrec_help(cur, acc):
        # acc holds the values collected so far; it is copied, never mutated
        if cur >= 0:
            newlist = acc.copy()
            newlist.append(cur)
            print(f"Calling countdown_tailrec_helper({cur-1}, {newlist})")
            res = countdown_tailrec_help(cur-1, newlist)
        else:
            res = acc
        print(f"countdown_tailrec_help({cur},{acc}) returns {res}")
        return res
    
    def countdown_tailrec(num):
        print(f"Calling countdown_tailrec_helper({num}, [])")
        return countdown_tailrec_help(num,[])
    
    print(countdown_tailrec(4))
    

    Then:

    Calling countdown_tailrec_helper(4, [])
    Calling countdown_tailrec_helper(3, [4])
    Calling countdown_tailrec_helper(2, [4, 3])
    Calling countdown_tailrec_helper(1, [4, 3, 2])
    Calling countdown_tailrec_helper(0, [4, 3, 2, 1])
    Calling countdown_tailrec_helper(-1, [4, 3, 2, 1, 0])
    countdown_tailrec_help(-1,[4, 3, 2, 1, 0]) returns [4, 3, 2, 1, 0]
    countdown_tailrec_help(0,[4, 3, 2, 1]) returns [4, 3, 2, 1, 0]
    countdown_tailrec_help(1,[4, 3, 2]) returns [4, 3, 2, 1, 0]
    countdown_tailrec_help(2,[4, 3]) returns [4, 3, 2, 1, 0]
    countdown_tailrec_help(3,[4]) returns [4, 3, 2, 1, 0]
    countdown_tailrec_help(4,[]) returns [4, 3, 2, 1, 0]
    [4, 3, 2, 1, 0]
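
The remark that Python does not optimize tail calls can be checked with a small sketch. It assumes CPython's default behaviour of raising RecursionError once the recursion limit is exceeded; the names countdown_tail and exceeds_stack are hypothetical.

```python
import sys

def countdown_tail(cur, acc=()):
    """Tail-recursive countdown using an immutable tuple as accumulator."""
    if cur < 0:
        return list(acc)
    # The recursive call is the last action, yet Python still adds a stack frame.
    return countdown_tail(cur - 1, acc + (cur,))

def exceeds_stack(num):
    """Return True if countdown_tail(num) runs out of stack depth."""
    try:
        countdown_tail(num)
    except RecursionError:
        return True
    return False

print(countdown_tail(4))  # [4, 3, 2, 1, 0]
# Going deeper than the recursion limit fails even though the call is in tail position.
print(exceeds_stack(sys.getrecursionlimit() + 100))
```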
    

    P.P.S.

    The non-tail-recursive version:

    def countdown(num):
        if num >= 0:
            print(f"Calling countdown({num-1})")
            rest = countdown(num-1) # not tail recursive: work remains after the call returns
            newlist = rest.copy()
            newlist.insert(0,num)
            res = newlist
        else:
            res = []
        print(f"countdown({num}) returns {res}")
        return res
    
    print(countdown(4))
    

    And so:

    Calling countdown(3)
    Calling countdown(2)
    Calling countdown(1)
    Calling countdown(0)
    Calling countdown(-1)
    countdown(-1) returns []
    countdown(0) returns [0]
    countdown(1) returns [1, 0]
    countdown(2) returns [2, 1, 0]
    countdown(3) returns [3, 2, 1, 0]
    countdown(4) returns [4, 3, 2, 1, 0]
    [4, 3, 2, 1, 0]