Tags: binary-search, z3py, proof-of-correctness

Trying to Prove binary search termination with Z3, but Z3 finds an invalid counter example


I am learning z3 and experimenting with it by trying to prove binary search correct. The first step is to ask whether the function even terminates. This should be provable by showing that the size function (r - l) always decreases. Here is my code:

from z3 import *


# prove that binary search always terminates
# this can be proven by the fact that r - l
# is always decreasing regardless of the
# executed branch in the loop

l, r = Ints("l r")

preconditions = And([
    l >= 0,
    r > l, # assume the range is not empty
])

mid = Int("mid")
termination = And([
    mid == (l + r) / 2,
    (r - l) > (mid - l),
    (r - l) > (r - mid + 1),
])

claim = Implies(preconditions, termination)

prove(claim)

Please note that I am assuming the search range is inclusive on the left and exclusive on the right.

However, when I run the code, z3 reports that it found a counterexample:

counterexample
[l = 0, mid = 1, r = 1]

This doesn't make sense to me, since the first conjunct of termination states that mid == (l + r) / 2.

I tried inlining mid into the two conditions, but still got a counterexample:

termination = And([
    (r - l) > ((l + r) / 2 - l),
    (r - l) > (r - (l + r) / 2 + 1),
])
counterexample
[l = 0, r = 1]

I checked that z3 does not coerce the division to Real by running the following in the interpreter:

>>> from z3 import *
>>> x, y = Ints("x y")
>>> ((x + y) / 2).sort()
Int

I also tried encoding the counterexample z3 found as constraints, to check whether it is compatible with my definition of mid, but z3 reported "no solution" for the following:

solve(mid == 1, mid == (l + r) / 2, l == 0, r == 1)

I am confused: why does z3 consider the counterexample it found to be valid?

I expected z3 to prove that my claim is true.


Solution

  • The model z3 is giving you is just fine. The counterexample is:

    [l = 0, mid = 1, r = 1]
    

    Let's see if it satisfies your preconditions:

    preconditions = And([
        l >= 0,
        r > l, # assume the range is not empty
    ])
    

    l >= 0 is true, since 0 >= 0, and r > l is true, since 1 > 0.

    Since you have an implication, which is false only when its premise holds and its conclusion does not, the counterexample is justified if your "termination" is not satisfied by this model. Let's see:

    termination = And([
        mid == (l + r) / 2,
        (r - l) > (mid - l),
        (r - l) > (r - mid + 1),
    ])
    

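    mid == (l+r) / 2 fails, since 1 != (0 + 1) / 2. (The right-hand side is 0 under SMT-LIB integer division semantics.) We don't have to look at the remaining conjuncts. So, with this model, your "preconditions" do not imply "termination," which is why you're getting a counterexample. The key point is that mid is a free variable, and its defining equation sits in the conclusion of the implication rather than in the premise. It therefore does not constrain mid at all: z3 is free to pick a value for mid that simply makes that conjunct false.

    Your inlined version fails for a different reason. With l = 0 and r = 1, mid is 0, and the second conjunct reads 1 > 1 - 0 + 1, that is, 1 > 2, which is false. After l = mid + 1, the new size is r - (mid + 1), not r - mid + 1.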

    How to fix

    To make things concrete, let's first see what the algorithm actually looks like. You haven't given us the "code" you're modeling, but I assume it looks something like:

    int l = 0, r = (int)vec.size();
    while(l != r) {
        int mid = (l + r) / 2;
        if (vec[mid] < val) l = mid + 1;
        else r = mid;
    }
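If you want to run the loop itself, here is a rough Python transcription of the fragment above. Like the fragment, it is an assumption about what is being modeled: the name `lower_bound` and the sorted-list input are hypothetical. It also asserts at runtime the two facts the proof is meant to establish, namely that r - l stays non-negative and strictly decreases on every iteration.

```python
def lower_bound(vec, val):
    """Index of the first element of sorted vec that is >= val (len(vec) if none)."""
    l, r = 0, len(vec)
    while l != r:
        old_size = r - l
        # For non-negative l + r, Python's // agrees with C's / and z3's Int division
        mid = (l + r) // 2
        if vec[mid] < val:
            l = mid + 1
        else:
            r = mid
        new_size = r - l
        # Termination metric: non-negative and strictly smaller than before
        assert 0 <= new_size < old_size
    return l

print(lower_bound([1, 3, 5, 7], 5))  # 2
print(lower_bound([1, 3, 5, 7], 8))  # 4
```

A runtime assertion only checks the inputs you actually try; the z3 proof covers all of them at once.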
    

    The actual details, language, data types, etc. are not important; what matters is the control flow. The loop terminates when l == r, so we should show that on every control path the distance between l and r shrinks. In other words, the metric r - l must remain non-negative and strictly decrease in each iteration. So let's code that:

    from z3 import *
    
    
    l, r = Ints("l r")
    
    preconditions = And([l >= 0, r > l])
    
    mid = (l+r) / 2
    
    # Starting metric
    startingMetric = r - l
    
    # if vec[mid] < val, then l = mid + 1
    metric1 = r - (mid + 1)
    
    # if vec[mid] >= val, then r = mid
    metric2 = mid - l
    
    termination = And(
         And(metric1 >= 0, metric1 < startingMetric)
       , And(metric2 >= 0, metric2 < startingMetric)
       )
    
    claim = Implies(preconditions, termination)
    
    prove(claim)
    

    If you run this, z3 prints proved.

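    This establishes that, starting with l >= 0 && r > l, the metric r - l decreases on whichever control path is taken, yet remains non-negative, thus ensuring termination. (When the metric reaches 0, l == r and the loop exits; otherwise r > l still holds and l >= 0 is preserved, so the same argument applies to the next iteration.)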

    Aside

    I should add that a "theoretical" bug in binary search is that mid = (l+r) / 2 doesn't always compute the midpoint when machine integers are used. If l and r are mathematical integers, as modeled by z3's Int type, then all is well. But if they are 32-bit or 64-bit machine integers, then l+r can overflow, and the computed midpoint will be incorrect. This is usually not a big deal in practice, because it only happens when the input vector is really large, possibly too large to fit in memory anyhow. But to be absolutely safe, you should compute the midpoint in a way that cannot overflow, for instance as l + (r - l) / 2. You can read more about this gotcha here: https://blog.research.google/2006/06/extra-extra-read-all-about-it-nearly.html. You can observe this yourself by changing the declarations in the program above to l, r = BitVecs("l r", 32) and looking at the counterexample produced.