
Is the following code tail recursive?


Is the function below tail recursive? And how can it be written more efficiently and clearly?

Based on this problem: we build a table of n rows (1-indexed). We start by writing 0 in the 1st row. Then, in every subsequent row, we look at the previous row and replace each occurrence of 0 with 01 and each occurrence of 1 with 10.

For example, for n = 3, the 1st row is 0, the 2nd row is 01, and the 3rd row is 0110. Given two integers n and k, return the kth (1-indexed) symbol in the nth row of a table of n rows.
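To make the replacement rule concrete, here is a minimal brute-force sketch that builds each row literally. The names `build_row` and `kth_symbol_bruteforce` are hypothetical, chosen for illustration only; since row n has 2**(n - 1) symbols, this is only practical for small n.

```python
def build_row(n: int) -> str:
    """Return the nth row (1-indexed) of the table as a string of '0'/'1'."""
    row = "0"  # row 1
    for _ in range(n - 1):
        # Replace every 0 with 01 and every 1 with 10, all in one pass.
        row = "".join("01" if c == "0" else "10" for c in row)
    return row


def kth_symbol_bruteforce(n: int, k: int) -> int:
    """Return the kth (1-indexed) symbol of row n by building the whole row."""
    return int(build_row(n)[k - 1])


for n in range(1, 5):
    print(n, build_row(n))
```

Running it prints rows 1 to 4: 0, 01, 0110, 01101001.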

def kthGrammar(self, n: int, k: int) -> int:

    def getKvalue(k, predecessor):
        digit_place = k%2
        if predecessor == 0:
            return 1 if digit_place == 0 else 0
        elif predecessor == 1:
            return 0 if digit_place == 0 else 1
    
    
    def helper(n,k):
        if n==1:
            return 0
        prevK = int((k+1)/2)
        return getKvalue(k, helper(n-1,prevK))

    return helper(n,k)

Solution

  • Is your function currently tail recursive? No. The result of the recursive call is passed to getKvalue, so the recursive call is not the last operation the function performs.
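For contrast, a generic illustration (factorial, not taken from the question): in the first function the multiplication happens after the recursive call returns, which is the same shape as getKvalue(k, helper(n-1, prevK)). The second carries the partial product in an accumulator parameter, so the recursive call is the very last thing evaluated and its result is returned unchanged.

```python
def factorial(n: int) -> int:
    # Not tail recursive: n * (...) still has to run after the call returns.
    if n == 0:
        return 1
    return n * factorial(n - 1)


def factorial_acc(n: int, acc: int = 1) -> int:
    # Tail recursive: the recursive call's result is returned as-is.
    if n == 0:
        return acc
    return factorial_acc(n - 1, acc * n)


print(factorial(5), factorial_acc(5))  # 120 120
```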

    Your function can be cleaned up dramatically, however. We will begin by replacing 0 and 1 with False and True.

    def kthGrammar(n: int, k: int) -> int:
    
        def getKvalue(k : int, predecessor : bool) -> bool:
            return (k % 2 == 0) != predecessor
        
        def helper(n : int, k : int) -> bool:
            if n==1:
                return False
            prevK = (k+1) // 2
            return getKvalue(k, helper(n-1,prevK))
    
        return int(helper(n,k))
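The Boolean one-liner can be checked against the original branches exhaustively, since only the parity of k and the predecessor matter. In this sketch, `get_k_value_original` and `get_k_value_bool` are hypothetical names for the question's integer version and the rewrite above.

```python
def get_k_value_original(k: int, predecessor: int) -> int:
    # The question's version, with 0/1 symbols (predecessor is assumed to be 0 or 1).
    digit_place = k % 2
    if predecessor == 0:
        return 1 if digit_place == 0 else 0
    return 0 if digit_place == 0 else 1


def get_k_value_bool(k: int, predecessor: bool) -> bool:
    # The rewrite: "k is even" XOR predecessor.
    return (k % 2 == 0) != predecessor


for predecessor in (0, 1):
    for k in (1, 2):  # one odd and one even position cover both parities
        assert get_k_value_original(k, predecessor) == int(
            get_k_value_bool(k, bool(predecessor))
        )
print("all four cases agree")
```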
    

    Inlining getKvalue, we can rewrite this further:

    def kthGrammar(n: int, k: int) -> int:
        
        def helper(n : int, k : int) -> bool:
            if n==1:
                return False
            prevK = (k+1) // 2
            return (k % 2 == 0) != helper(n-1,prevK)
    
        return int(helper(n,k))
    

    Now, we try something rather clever. We define helper2(n : int, k : int, b : bool) = (b != helper(n, k)). How can we implement helper2 recursively?

    Clearly, if n = 1, then helper2(n, k, b) = (b != False) = b. Otherwise, we have helper2(n, k, b) = (b != helper(n, k)) = (b != ((k % 2 == 0) != helper(n - 1, (k + 1) // 2))) = ((b != (k % 2 == 0)) != helper(n - 1, (k + 1) // 2)) = helper2(n - 1, (k + 1) // 2, b != (k % 2 == 0)).

    Note that I used the fact that for Booleans, != is exclusive or, which is associative: a != (b != c) is identical to (a != b) != c.

    Finally, note that helper(n, k) = (False != helper(n, k)) = helper2(n, k, False).

    So we define

    def kthGrammar(n: int, k: int) -> int:
        
        def helper2(n : int, k : int, b : bool) -> bool:
            if n==1:
                return b
            prevK = (k+1) // 2
            return helper2(n - 1, prevK, b != (k % 2 == 0))
    
        return int(helper2(n, k, False))
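To gain some confidence in the derivation, one can check numerically that the accumulator version agrees with the non-tail-recursive version at every valid position of the first few rows. This is a sketch: `kth_grammar_recursive` and `kth_grammar_tail` are hypothetical names for the second and third versions above, and the bound of 10 rows is arbitrary.

```python
def kth_grammar_recursive(n: int, k: int) -> int:
    def helper(n: int, k: int) -> bool:
        if n == 1:
            return False
        return (k % 2 == 0) != helper(n - 1, (k + 1) // 2)
    return int(helper(n, k))


def kth_grammar_tail(n: int, k: int) -> int:
    def helper2(n: int, k: int, b: bool) -> bool:
        # Intended invariant: helper2(n, k, b) == (b != helper(n, k)).
        if n == 1:
            return b
        return helper2(n - 1, (k + 1) // 2, b != (k % 2 == 0))
    return int(helper2(n, k, False))


for n in range(1, 11):
    for k in range(1, 2 ** (n - 1) + 1):  # row n has 2**(n - 1) positions
        assert kth_grammar_recursive(n, k) == kth_grammar_tail(n, k)
print("versions agree for n <= 10")
```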
    

    Now, we have a tail recursive function. Tail recursion is just another way to express iteration, so we can easily rewrite this to use a while loop as follows:

    def kthGrammar(n : int, k : int) -> int:
        b = False
        while n != 1:
            n, k, b = n - 1, (k + 1) // 2, b != (k % 2 == 0)
        return int(b)
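Rewriting as a loop is more than cosmetic in Python: CPython does not perform tail-call elimination, so the tail-recursive helper2 still uses one stack frame per row and raises RecursionError once n exceeds the interpreter's recursion limit (see `sys.getrecursionlimit()`), while the loop runs in constant stack space. A minimal demonstration, assuming a standard CPython interpreter; the depth `sys.getrecursionlimit() + 100` is an arbitrary value just past the limit, and k = 1 is used because it is valid for every n.

```python
import sys


def helper2(n: int, k: int, b: bool) -> bool:
    # Tail-recursive accumulator version from above.
    if n == 1:
        return b
    return helper2(n - 1, (k + 1) // 2, b != (k % 2 == 0))


def kth_grammar_loop(n: int, k: int) -> int:
    # The same computation written as a while loop.
    b = False
    while n != 1:
        n, k, b = n - 1, (k + 1) // 2, b != (k % 2 == 0)
    return int(b)


n = sys.getrecursionlimit() + 100
try:
    helper2(n, 1, False)
    print("recursive version finished")
except RecursionError:
    print("recursive version: RecursionError")
print("loop version:", kth_grammar_loop(n, 1))
```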
    

    Which can again be replaced by

    def kthGrammar(n : int, k : int) -> int:
        b = False
        for _n in range(n, 1, -1):
            k, b = (k + 1) // 2, b != (k % 2 == 0)
        return int(b)
    

    Of course, since the loop variable is never used, there is no reason to count down from n; counting up gives the same n - 1 iterations. So we get

    def kthGrammar(n : int, k : int) -> int:
        b = False
        for _n in range(1, n):
            k, b = (k + 1) // 2, b != (k % 2 == 0)
        return int(b)
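To see how the loop arrives at its answer, here is a small tracing sketch of the version just above (`trace_kth_grammar` is a hypothetical name) that records (k, b) before the first step and after each step. For n = 4 the row is 01101001, and its 6th symbol is 0.

```python
def trace_kth_grammar(n: int, k: int) -> list[tuple[int, bool]]:
    """Run the for-loop version, recording (k, b) at each step."""
    b = False
    states = [(k, b)]
    for _n in range(1, n):
        # The right-hand side uses the old k, so the parity test sees k before halving.
        k, b = (k + 1) // 2, b != (k % 2 == 0)
        states.append((k, b))
    return states


for state in trace_kth_grammar(4, 6):
    print(state)
```

The states are (6, False), (3, True), (2, True), (1, False), so the answer is 0. Each step moves to the parent position (k + 1) // 2 in the previous row and flips b whenever k is even, that is, whenever the symbol is the second half of a 01 or 10 pair.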
    

    Note that we can actually perform one further optimisation. Once k = 1, the line

    k, b = (k + 1) // 2, b != (k % 2 == 0)

    is a no-op, because (1 + 1) // 2 = 1 and b != (1 % 2 == 0) is just b. So the final form is

    def kthGrammar(n : int, k : int) -> int:
        b = False
        for _n in range(1, n):
            if k == 1:
                break
            k, b = (k + 1) // 2, b != (k % 2 == 0)
        return int(b)
    

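    This is much more efficient when log2(k) is much smaller than n (for example, when k <= n): each iteration roughly halves k, so the loop stops after about log2(k) iterations, and the runtime is O(min(n, log k)) instead of O(n).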
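The runtime claim can be checked empirically with a sketch that mirrors the final version but also counts loop iterations (`kth_grammar_counted` is a hypothetical name). Each iteration maps k to ceil(k / 2), so k reaches 1 after ceil(log2(k)) iterations, which equals `(k - 1).bit_length()` for k >= 1, and the loop never runs more than n - 1 times. The brute-force row builder is included only to confirm that the answers are still correct for small n.

```python
def kth_grammar_counted(n: int, k: int) -> tuple[int, int]:
    """Final version with early exit; returns (symbol, iterations performed)."""
    b = False
    iterations = 0
    for _n in range(1, n):
        if k == 1:
            break  # every remaining step would be a no-op
        k, b = (k + 1) // 2, b != (k % 2 == 0)
        iterations += 1
    return int(b), iterations


def build_row(n: int) -> str:
    """Brute force: build row n by literal replacement (small n only)."""
    row = "0"
    for _ in range(n - 1):
        row = "".join("01" if c == "0" else "10" for c in row)
    return row


# Correctness against brute force for the first few rows.
for n in range(1, 11):
    row = build_row(n)
    for k in range(1, len(row) + 1):
        assert kth_grammar_counted(n, k)[0] == int(row[k - 1])

# Iteration count is min(n - 1, ceil(log2(k))); (k - 1).bit_length() == ceil(log2(k)).
for n, k in [(30, 1), (30, 5), (30, 2 ** 29), (1000, 3)]:
    _, iterations = kth_grammar_counted(n, k)
    assert iterations == min(n - 1, (k - 1).bit_length())
    print(n, k, iterations)
```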