Tags: python, z3, smt, z3py, theorem-proving

(Z3Py) Using all_smt to generate all solutions of a model


The online Programming Z3 book contains a full implementation of an all_smt function that generates an iterator over all solutions (models) of a set of constraints. I am using this all_smt function verbatim, alongside a naive implementation, all_smt_slow_method, that does the same thing. My all_smt_slow_method function produces the correct answer, while all_smt produces an incorrect one. I suspect that the published all_smt implementation is correct and that I'm calling it wrong somehow.

The problem:

1. x, y, z are integers between 0 and 11, inclusive.

2. The absolute difference between x and y must be less than or equal to 8.

3. The absolute difference between y and z must be less than or equal to 8.

How many valid (x, y, z) triples satisfy the above constraints?

My solution:

from z3 import *


def all_smt(s, initial_terms):
    def block_term(s, m, t):
        s.add(t != m.eval(t, model_completion=True))

    def fix_term(s, m, t):
        s.add(t == m.eval(t, model_completion=True))

    def all_smt_rec(terms):
        if sat == s.check():
            m = s.model()
            yield m
            for i in range(len(terms)):
                s.push()
                block_term(s, m, terms[i])
                for j in range(i):
                    fix_term(s, m, terms[j])
                yield from all_smt_rec(terms[i:])
                s.pop()

    yield from all_smt_rec(list(initial_terms))


def all_smt_slow_method(s, initial_terms, x, y, z):
    s.add(initial_terms)
    manual_counter = 0
    while s.check() == sat:
        s.add(Or(x != s.model()[x], y != s.model()[y], z != s.model()[z]))
        manual_counter += 1
    return manual_counter


def main():
    maximum_delta = 8
    range = 12
    x, y, z = Ints("x y z")
    initial_terms = [
        0 <= x,
        x < range,
        0 <= y,
        y < range,
        0 <= z,
        z < range,
    ]
    initial_terms += [
        x - y <= maximum_delta,
        y - x <= maximum_delta,
        y - z <= maximum_delta,
        z - y <= maximum_delta,
    ]

    manual_counter = all_smt_slow_method(Solver(), initial_terms, x, y, z)

    all_smt_generator = all_smt(Solver(), initial_terms)
    all_smt_counter = sum(1 for x in all_smt_generator)

    print(f"{range ** 3 = } {manual_counter = } {all_smt_counter = }")


if __name__ == "__main__":
    main()

manual_counter evaluates to 1468, which I believe is the correct answer. However, all_smt yields only 111 models, which is incorrect.
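Since there are only 12³ = 1728 candidate triples, the count of 1468 can be cross-checked by brute force without Z3 at all. A minimal sketch in plain Python; the function name brute_force_count and its parameters are illustrative, not part of the original code:

```python
# Brute-force cross-check of the expected count (no Z3 involved).
# Enumerates every (x, y, z) with 0 <= x, y, z < upper and keeps the
# triples where |x - y| <= max_delta and |y - z| <= max_delta.
from itertools import product


def brute_force_count(upper=12, max_delta=8):
    """Count triples in [0, upper)^3 satisfying both difference bounds."""
    return sum(
        1
        for x, y, z in product(range(upper), repeat=3)
        if abs(x - y) <= max_delta and abs(y - z) <= max_delta
    )


print(brute_force_count())  # 1468
```

This only works because the domain is tiny and finite; the point of all_smt is to get the same enumeration from the solver.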


Solution

  • Your use of the all_smt function isn't quite right. Use it like this instead:

    s = Solver()
    s.add(initial_terms)
    all_smt_generator = all_smt(s, [x, y, z])
    

    That is, you should first create a solver, add the constraints to it, and then pass it to all_smt as the first argument. The second argument is the list of variables. (In general the terms can be arbitrary expressions, but to enumerate all solutions you should pass the variables you care about. The way you called it, the constraints were never added to the solver, so you essentially counted how many distinct combinations of truth values your ten constraints can take over unconstrained integers. It'd be interesting to check that this number is indeed 111, though I do trust z3 on that.)

    If you modify your program like this, you'll see both methods produce the correct count of 1468 solutions.

    Side note: even though what you wrote is equivalent, using Abs would make your solution clearer. That is, instead of writing [x - y <= maximum_delta, y - x <= maximum_delta], you can shorten it to Abs(x - y) <= maximum_delta, and similarly for y and z.

    There's no built-in way to state a constraint of the form a <= b < c (Python's chained comparison 0 <= x < 12 does not work on z3 expressions), but you can create one yourself. Just define:

    # CO: range closed on the left, open on the right
    def inCORange(left, x, right):
        return And(left <= x, x < right)
    

    and use it as you wish. This is in fact how z3py is used most of the time: you write small helper functions like this one that construct z3 constraints for you.