python, recursion

Logical error in sequence 011212201 question


I am trying to solve the following challenge:

The example sequence 011212201220200112 ... is constructed as follows:

The first element in the sequence is 0.

For each iteration, repeat the following action: take a copy of the entire current sequence, replace 0 with 1, 1 with 2, and 2 with 0, and place it at the end of the current sequence. E.g. 0 -> 01 -> 0112 -> 01121220 -> ...
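A minimal sketch of this construction step, useful for seeing the first few iterations concretely. The helper name `build` is hypothetical and not part of the challenge:

```python
def build(iterations):
    """Build the sequence by applying the doubling step `iterations` times."""
    seq = [0]  # the first element is 0
    for _ in range(iterations):
        # copy the whole sequence, map 0->1, 1->2, 2->0, and append the copy
        seq += [(x + 1) % 3 for x in seq]
    return seq

for k in range(4):
    print("".join(map(str, build(k))))  # 0, 01, 0112, 01121220 (one per line)
```

After k iterations the sequence has 2**k elements, so materializing it is only practical for small N, not for N up to 3000000000.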

Create an algorithm which determines what number is at the Nth position in the sequence (using 0-based indexing).

Input

Your program should read lines from standard input. Each line contains an integer N such that 0 <= N <= 3000000000.

Output

Print out the number which is at the Nth position in the sequence.
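One way to read the construction without building the sequence, sketched here under the assumption that it is stated exactly as above: if p is the largest power of two with p <= N, then position N lies in the appended copy of the prefix of length p, so its value is the value at N - p shifted by one. The function name `nth_recursive` is hypothetical:

```python
def nth_recursive(n):
    """Element at 0-based position n, derived directly from the doubling rule."""
    if n == 0:
        return 0  # the first element is 0
    p = 1 << (n.bit_length() - 1)  # largest power of two <= n
    # n lies in [p, 2p), i.e. in the shifted copy of the first p elements,
    # so it equals element n - p with 0->1, 1->2, 2->0 applied
    return (nth_recursive(n - p) + 1) % 3

print(nth_recursive(11))  # 0
```

Each call removes the highest set bit of n, so the recursion depth is at most the number of 1 bits (at most 32 for N <= 3000000000), and the result is the count of 1 bits modulo 3.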

I spotted a pattern and implemented it as follows:

def find_nth_element(N):
    result = 0
    while N > 0:
        N -= 1  
        result = (result + 1) % 3 if N % 2 == 1 else result
        N //= 2
    return result

import sys
for line in sys.stdin:
    N = int(line.strip())
    print(find_nth_element(N))

However, my code fails the test cases.

For instance, when the input is 11, the expected output is 0, but my code returns 1.

I thought I had captured the logic needed. Where is my mistake?


Solution

  • The logic of repeatedly dividing the given input by 2 is indeed right, but the mistake is the statement N -= 1.

    When the number being divided by 2 is odd, integer division drops the 1 that makes it odd. Subtracting that 1 yourself is still wrong here, because:

    1. It is not necessary: you already use the integer division operator (//), which discards that 1 for you.
    2. Even if you did subtract that 1, it should only happen when the number is odd, and only right before the division, not before you check the parity to update the result.
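To make point 2 concrete, here is a small instrumented copy of the question's loop that records which values have their parity tested, run with and without the decrement. The helper `trace` and its `decrement` flag are hypothetical, added only for illustration:

```python
def trace(N, decrement):
    """Return the values whose parity the loop tests, plus the final result."""
    checked = []
    result = 0
    while N > 0:
        if decrement:
            N -= 1  # the statement from the question
        checked.append(N)
        result = (result + 1) % 3 if N % 2 == 1 else result
        N //= 2
    return checked, result

print(trace(11, decrement=True))   # ([10, 4, 1], 1)
print(trace(11, decrement=False))  # ([11, 5, 2, 1], 0)
```

With the decrement, the loop never tests 11 itself (binary 1011), so it sees only one odd value instead of three.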

    Anyhow, to fix your attempt, you just have to remove this statement:

    N -= 1
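For reference, this is the question's function with that one statement removed (logic otherwise unchanged, with the conditional expression spelled out as an `if`):

```python
def find_nth_element(N):
    result = 0
    while N > 0:
        # an odd N means its lowest bit is 1: apply one more shift 0->1->2->0
        if N % 2 == 1:
            result = (result + 1) % 3
        N //= 2  # drop the lowest bit; no separate "N -= 1" is needed
    return result

print(find_nth_element(11))  # 0
```

The stdin loop from the question can stay as it is.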
    

    Other remarks

    Some other remarks about this algorithm:

    1. Reducing code

    The algorithm essentially adds one to the result (mod 3) for every 1 in the binary representation of N. So you could use the built-in int.bit_count method (available since Python 3.10) to count them all:

    def find_nth_element(N):
        return N.bit_count() % 3
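If the judge runs a Python version older than 3.10, where int.bit_count does not exist, a common equivalent is to count the '1' characters in the binary string produced by bin():

```python
def find_nth_element(N):
    # bin(11) == '0b1011'; the '0b' prefix contains no '1', so the count is exact
    return bin(N).count("1") % 3

print(find_nth_element(11))  # 0
```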
    

    2. Testing your solution

    To verify your solution yourself, without having to rely on an online judge, you could implement a function that generates the described sequence:

    def gen():
        lst = [0]  # the sequence built so far
        i = 0
        while True:
            if i >= len(lst):
                # append the shifted copy: 0->1, 1->2, 2->0
                lst += [(x + 1) % 3 for x in lst]
            yield lst[i]
            i += 1
    

    With this you can write a loop that verifies your solution and, if something is wrong, prints the first input for which it fails:

    for n, expected in enumerate(gen()):
        res = find_nth_element(n)
        if res != expected:
            print(f"wrong result for n={n}: got {res}, expected {expected}")
            break
        if n == 999:  # inputs 0..999 have all passed
            print("The first 1000 inputs were verified")
            break
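The generator only reaches small N. A minimal sketch for spot-checking large inputs, assuming the doubling rule as stated in the challenge: for n < 2**k, position n + 2**k lies in the appended copy, so its value must be the value at n shifted by one. Sampling k up to 31 covers positions below 2**32, which includes every N <= 3000000000. The constant `LIMIT` and the sample count are arbitrary choices:

```python
import random

def find_nth_element(N):
    # the corrected loop: count the 1 bits of N, modulo 3
    result = 0
    while N > 0:
        if N % 2 == 1:
            result = (result + 1) % 3
        N //= 2
    return result

LIMIT = 3_000_000_000
for _ in range(10_000):
    k = random.randrange(LIMIT.bit_length())  # k in 0..31
    n = random.randrange(2 ** k)              # a position in the first half of the block
    # n + 2**k is in the shifted copy, so its value must be one step further
    assert find_nth_element(n + 2 ** k) == (find_nth_element(n) + 1) % 3, n
print("doubling rule holds on all sampled positions")
```

This checks the rule that defines the sequence rather than comparing against a second bit-counting implementation, so it can catch a bug that both implementations would share.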
    

    Happy coding!