Tags: python, invariants, correctness, loop-invariant

Optimal placement of assert statements to ensure correctness using an invariant


I'm trying to understand invariants in programming via real examples written in Python. I'm confused about where to place assert statements to check for invariants.

My research has turned up different patterns for where to check invariants. For example:

  • before the loop starts
  • before each iteration of the loop
  • after the loop terminates

vs

... // the Loop Invariant must be true here
while ( TEST CONDITION ) {
// top of the loop
...
// bottom of the loop
// the Loop Invariant must be true here
}
// Termination + Loop Invariant = Goal
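The skeleton above can be made concrete in Python. The sketch below uses a simple summing loop chosen only for illustration (`sum_to` is a hypothetical name, not from the question) and places an assert at each point the skeleton marks: before the loop, at the top and bottom of each iteration, and after termination, where the exit condition and the invariant together give the goal.

```python
def sum_to(k: int) -> int:
    """Return 1 + 2 + ... + k, asserting the loop invariant at each
    checkpoint from the skeleton (illustrative sketch)."""
    assert k >= 0  # precondition
    i, total = 0, 0
    # Invariant: total == 1 + ... + i == i*(i+1)//2 and 0 <= i <= k
    assert total == i * (i + 1) // 2 and 0 <= i <= k  # before the loop
    while i < k:
        # Top of the loop: invariant holds and the test condition is true.
        assert total == i * (i + 1) // 2 and 0 <= i < k
        i += 1
        total += i
        # Bottom of the loop: the body must re-establish the invariant.
        assert total == i * (i + 1) // 2 and 0 <= i <= k
    # Termination (not i < k) + invariant (i <= k) => i == k, i.e. the goal.
    assert i == k and total == k * (k + 1) // 2
    return total


print(sum_to(10))
```

The two in-loop asserts check the same state back to back (bottom of one iteration, then top of the next), which is why one of them is usually enough.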

Below is code for an invariant example from a maths book. There are two versions, one using a function and one not. I expect it makes no difference, but I want to be thorough.

My questions are:

  • What is the minimum number of assert statements I need to ensure program correctness, in keeping with the invariant?
  • Which of the assert statements in my examples are redundant?
  • If there is more than one answer to the above, which would be considered best practice?

Ideally I'd like to see a rewrite of my code that follows best practices and addresses any issues I may have overlooked so far.

Any input much appreciated.

Here's the exercise:

E2. Suppose the positive integer n is odd. First Al writes the numbers 1, 2,..., 2n on the blackboard. Then he picks any two numbers a, b, erases them, and writes, instead, |a − b|. Prove that an odd number will remain at the end.

Solution. Suppose S is the sum of all the numbers still on the blackboard. Initially this sum is S = 1 + 2 + ··· + 2n = n(2n + 1), an odd number. Each step reduces S by a + b − |a − b| = 2 min(a, b), which is an even number. So the parity of S is an invariant: during the whole reduction process we have S ≡ 1 (mod 2). Initially the parity is odd, so it will also be odd at the end.
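The two arithmetic facts the solution relies on can be spot-checked numerically. The sketch below is a finite sanity check over small values (the bound of 50 is arbitrary), not a proof; `step_change` is a hypothetical helper name.

```python
def step_change(a: int, b: int) -> int:
    """How much S drops when a and b are replaced by |a - b|."""
    return (a + b) - abs(a - b)


# Fact 1: each step reduces S by 2*min(a, b), an even number.
# 0 is included because |a - b| can produce 0 when a == b.
for a in range(0, 50):
    for b in range(0, 50):
        assert step_change(a, b) == 2 * min(a, b)
        assert step_change(a, b) % 2 == 0

# Fact 2: 1 + 2 + ... + 2n == n(2n + 1); since 2n + 1 is odd,
# the sum has the same parity as n, so it is odd when n is odd.
for n in range(1, 50):
    s = sum(range(1, 2 * n + 1))
    assert s == n * (2 * n + 1)
    assert (s % 2 == 1) == (n % 2 == 1)

print("both facts hold for the tested ranges")
```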

import random

def invariant_example(n):
    xs = list(range(1, 2 * n + 1))  # 1, 2, ..., 2n on the blackboard
    print(xs)
    assert sum(xs) % 2 == 1
    while len(xs) >= 2:
        assert sum(xs) % 2 == 1
        a, b = random.sample(xs, 2)
        print(f"a: {a}, b: {b}, xs: {xs}")
        xs.remove(a)
        xs.remove(b)
        xs.append(abs(a - b))
        assert sum(xs) % 2 == 1
    assert sum(xs) % 2 == 1
    return xs
    
print(invariant_example(5))

n = 5
xs = list(range(1, 2 * n + 1))  # 1, 2, ..., 2n on the blackboard
print(xs)
assert sum(xs) % 2 == 1
while len(xs) >= 2:
    assert sum(xs) % 2 == 1
    a, b = random.sample(xs, 2)
    print(f"a: {a}, b: {b}, xs: {xs}")
    xs.remove(a)
    xs.remove(b)
    xs.append(abs(a - b))
    assert sum(xs) % 2 == 1
assert sum(xs) % 2 == 1
print(xs)

Solution

  • The only technically redundant assert statements you have are the two inside the loop: you need one of them, but not both.

    For example:

    If you have both of them, the first assert in the while loop will execute immediately after the second (as the code will return to the top of the loop). No values change in between those calls, so the first assert statement will always have the same result as the second.

    Best practice would probably be to keep the assert at the top of the loop, to prevent code within the loop from executing if the loop invariant is violated.

    EDIT: As Kelly Bundy noted, the final assert statement should also include the loop exit condition (here `len(xs) < 2`), since termination plus the invariant is what establishes the goal. I forgot to mention this above.
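Putting the advice above together, a rewrite of `invariant_example` might look like the sketch below. It is one reasonable version, not the only correct one: it adds a precondition assert for the exercise's assumption that n is a positive odd integer, keeps a single invariant assert at the top of the loop, and makes the final assert check the exit condition together with the invariant. The optional `rng` parameter is a hypothetical addition (not in the question) so that runs can be reproduced with a seed; the debug prints are dropped.

```python
import random
from typing import Optional


def invariant_example(n: int, rng: Optional[random.Random] = None) -> int:
    # Precondition from the exercise: n is a positive odd integer.
    assert n > 0 and n % 2 == 1, "n must be a positive odd integer"
    if rng is None:
        rng = random.Random()  # hypothetical: injectable RNG for seeding
    xs = list(range(1, 2 * n + 1))  # 1, 2, ..., 2n on the blackboard

    while len(xs) >= 2:
        # Invariant, checked once per iteration at the top of the loop.
        # The loop always runs at least once (2n >= 2), so the first check
        # also covers the state before the loop.
        assert sum(xs) % 2 == 1
        a, b = rng.sample(xs, 2)  # two distinct positions on the board
        xs.remove(a)
        xs.remove(b)
        xs.append(abs(a - b))

    # Exit condition + invariant. An odd sum rules out an empty list,
    # so exactly one number is left and it is odd.
    assert len(xs) < 2 and sum(xs) % 2 == 1
    return xs[0]


print(invariant_example(5, random.Random(0)))
```

Keep in mind that Python skips assert statements entirely when run with the `-O` flag, so these checks test and document the invariant at runtime rather than prove it; the proof is still the parity argument from the exercise.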