
Struggling to find the correct loop invariant


I have the following code:

public class Main {
    public static void main(String[] args) {
        int a = 3;
        int b = 7;

        int x = b; // x = b
        int res = a; // res = a
        int y = 1;

        int invariant = 0;

        System.out.println("a|b|x|y|res|invariant");
        while (x > 0) {
            if (x % 2 == 0) {
                y = 2 * y;
                x = x / 2;
            } else {
                res = res + y;
                y = 2 * y;
                x = (x - 1) / 2;
            }
            invariant = y + 2; // candidate invariant, compared against res
            String output = String.format("%d|%d|%d|%d|%d|%d", a, b, x, y, res, invariant);
            System.out.println(output);
        }
        // < res = a + b >
    }
}

Which gives the following output:

a|b|x|y|res|invariant
3|7|3|2|4|4
3|7|1|4|6|6
3|7|0|8|10|10

However, if I change the numbers, invariant is no longer equal to res, so my loop invariant for this problem is not correct.

I'm struggling really hard to find the correct loop invariant and would be glad if there's any hint that someone can give me.

My first impression after looking at the code and my results is that the loop invariant changes based on a and b. Let's say both a and b are odd numbers, as they are in my example; then my loop invariant is correct (at least it seems like it).

Is it correct to assume a loop invariant like the following?

< res = y + 2 && a % 2 != 0 && b % 2 != 0 >

I tried different numbers, and it seems like any time I change them there's a different loop invariant. I struggle to find any pattern whatsoever.

I would really appreciate it if someone could give me a hint or a general idea of how to solve this.

Thanks


Solution

  • This loop computes the sum a+b (for b >= 0). res is initialized to a. Then, in each iteration of the loop, the next bit of the binary representation of b (starting with the least significant bit), weighted by its power of 2, is added to res, until the loop ends and res holds a+b.

    How it works:

    x is initialized to b. Each iteration eliminates the least significant bit of x. If that bit is 0, you simply divide x by 2. If it's 1, you subtract 1 and divide by 2 (dividing by 2 alone would be sufficient, since (x-1)/2 == x/2 when x is an odd positive int). Only when you encounter a 1 bit do you have to add it (multiplied by the correct power of 2) to the result. y holds that power of 2.
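The observation that both branches can share the same division can be sketched as follows. This is only an illustrative rewrite of the loop from the question, assuming b >= 0; the class name BitwiseSum and the method name addViaBits are made up for this example and do not appear in the original code.

```java
// Hypothetical helper class; the names here are illustrative only.
final class BitwiseSum {

    // Returns a + b for b >= 0 by walking the bits of b,
    // least significant bit first.
    static int addViaBits(int a, int b) {
        int x = b;   // bits of b that have not been processed yet
        int y = 1;   // power of 2 belonging to the current bit
        int res = a;
        while (x > 0) {
            if (x % 2 != 0) {
                res = res + y; // current bit is 1: add its weight
            }
            y = 2 * y;
            x = x / 2; // integer division drops the lowest bit, odd or even
        }
        return res;
    }

    public static void main(String[] args) {
        System.out.println(addViaBits(3, 7));
    }
}
```

With a = 3 and b = 7 this goes through the same values of x, y and res as the table in the question.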

    In your a=3, b=7 example, the binary representation of b is 111:

    • After the first iteration, the value of res is a + 1 (binary) == a + 1 == 4
    • After the second iteration, the value of res is a + 11 (binary) == a + 3 == 6
    • After the last iteration, the value of res is a + 111 (binary) == a + 7 == 10

    You could write the invariant as:

    invariant = a + (b & (y - 1));
    

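    This takes advantage of the fact that at the end of the i-th iteration (i starting from 1), y holds 2^i, so y - 1 == 2^i - 1 is a number whose binary representation is i 1 bits (i.e. 11...11 with i bits). When you & this number with b, you get the i least significant bits of b. Before the loop, y == 1, so the mask is 0 and the invariant reduces to res == a. When the loop exits, x == 0, meaning every bit of b has been consumed, so b & (y - 1) == b and the invariant gives res == a + b.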
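One way to gain confidence in a candidate invariant is to evaluate it before the loop and after every iteration for many inputs. The sketch below does that for the mask form above and also for an equivalent product form, res + x * y == a + b, which is not from the answer but follows from the same reasoning when b >= 0. The class and method names are hypothetical.

```java
// Hypothetical checker class; the names here are illustrative only.
final class InvariantCheck {

    // Runs the loop from the question and reports whether both candidate
    // invariants hold before the loop and after every iteration.
    // Assumption: b >= 0 (for negative b the loop body never runs).
    static boolean invariantHolds(int a, int b) {
        int x = b;
        int res = a;
        int y = 1;
        boolean ok = check(a, b, x, y, res); // must hold on loop entry
        while (x > 0) {
            if (x % 2 == 0) {
                y = 2 * y;
                x = x / 2;
            } else {
                res = res + y;
                y = 2 * y;
                x = (x - 1) / 2;
            }
            ok = ok && check(a, b, x, y, res); // must hold after each iteration
        }
        // On exit x == 0, so the postcondition res == a + b should follow.
        return ok && res == a + b;
    }

    private static boolean check(int a, int b, int x, int y, int res) {
        boolean maskForm = res == a + (b & (y - 1)); // form given in the answer
        boolean productForm = res + x * y == a + b;  // equivalent form, assuming b >= 0
        return maskForm && productForm;
    }

    public static void main(String[] args) {
        boolean all = true;
        for (int a = 0; a <= 20; a++) {
            for (int b = 0; b <= 20; b++) {
                all = all && invariantHolds(a, b);
            }
        }
        System.out.println("invariant holds for all tested pairs: " + all);
    }
}
```

Testing like this cannot prove the invariant, but a single failing pair is enough to reject a guess such as res == y + 2.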