
What is the most efficient way of checking N-way equation equivalence in Z3?


Suppose I have a set of Z3 expressions:

exprs = [A, B, C, D, E, F]

I want to check whether any of them are equivalent and, if so, determine which. The most obvious way is just an N×N comparison (assume exprs is composed of some arbitrarily-complicated boolean expressions instead of the simple numbers in the example):

from z3 import *
exprs = [IntVal(1), IntVal(2), IntVal(3), IntVal(4), IntVal(3)]
for i in range(len(exprs) - 1):
  for j in range(i+1, len(exprs)):
    s = Solver()
    s.add(exprs[i] != exprs[j])
    if unsat == s.check():
      quit(f'{(i, j)} are equivalent')

Is this the most efficient method, or is there some way of quantifying over a set of arbitrary expressions? It would also be acceptable for this to be a two-step process where I first learn whether any of the expressions are equivalent, and then do a longer check to see which specific expressions are equivalent.


Solution

  • As with anything performance-related, the answer is "it depends." Before delving into options, though, note that z3 supports Distinct, which asserts that any number of expressions are pairwise different: https://z3prover.github.io/api/html/namespacez3py.html#a9eae89dd394c71948e36b5b01a7f3cd0

    Though of course, you've a more complicated query here. I think the following two algorithms are your options:

    Explicit pairwise checks

    Depending on your constraints, the simplest thing to do might be to call the solver multiple times, as you alluded to. To start with, assert Not(Distinct(...)) over all the expressions and check whether it is satisfiable (i.e., check whether some pair of these expressions can be made equal). If the answer comes back unsat, no two expressions can ever be equal, so none of them are equivalent. Otherwise, go with your loop as before until you hit a pair whose disequality is unsat; that pair is equivalent.

    Doing multiple checks together

    You can also solve your problem using a modified algorithm, though with more complicated constraints, and hopefully faster.

    To do so, create N(N-1)/2 booleans, one for each pair, each defined to be true exactly when the two expressions of that pair differ. To illustrate, let's say you have the expressions A, B, and C. Create:

    • X0 = A != B
    • X1 = A != C
    • X2 = B != C

    Now loop:

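    • Ask if X0 || X1 || X2 is satisfiable.
    • If the solver comes back unsat, then every pair still on the list is equivalent (on the first round, that means all of A, B, and C). You're done.
    • If the solver comes back sat, then at least one of the disjuncts X0, X1 or X2 is true. Those pairs differ in this model, so they are not equivalent. Use the model the solver gives you to determine which disjuncts are false, and continue with only those until you get unsat or run out of disjuncts (in which case no pair is equivalent).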

    Here's a simple concrete example. Let's say the expressions are {1, 1, 2}:

    • Ask if 1 != 1 || 1 != 2 || 1 != 2 is sat.
    • It'll be sat. In the model, you'll have at least one of these disjuncts true, and it won't be the first one! In this case the last two. Drop them from your list, leaving you with 1 != 1.
    • Ask again if 1 != 1 is satisfiable. The answer will be unsat and you're done.

    In the worst case you'll make N(N-1)/2 calls to the solver, if it happens that no pair is equivalent and each model eliminates only one disjunct. This is where the first call with Not(Distinct(A, B, C, ...)) is important: if it comes back unsat you can stop right away, and otherwise you start out knowing that at least one pair can be made equal; hopefully iterating faster.

    Summary

    My initial hunch is that the second algorithm above will be more performant; though it really depends on what your expressions really look like. I suggest some experimentation to find out what works the best in your particular case.

    A Python solution

    Here's the algorithm coded:

    from z3 import *
    
    exprs = [IntVal(i) for i in [1, 2, 3, 4, 3, 2, 10, 10, 1]]
    
    s = Solver()
    
    # One boolean per pair; eq_i_j is true exactly when exprs[i] and exprs[j] differ.
    bools = []
    for i in range(len(exprs) - 1):
        for j in range(i + 1, len(exprs)):
            b = Bool(f'eq_{i}_{j}')
            bools.append(b)
            s.add(b == (exprs[i] != exprs[j]))
    
    # First check whether any pair can be made equal at all
    s.push()
    s.add(Not(Distinct(*exprs)))
    if s.check() == unsat:
        quit("They're all distinct")
    s.pop()
    
    while True:
        # bools becomes empty once every pair has been seen to differ in some
        # model. This can happen with symbolic (non-constant) expressions.
        if not bools:
            quit("No two expressions are equivalent")
    
        if s.check(Or(*bools)) == unsat:
            # None of the remaining pairs can differ, so they are equivalent.
            print("Equivalent expressions:")
            for b in bools:
                print(f'   {b}')
            quit('Done')
        else:
            # Use the model to keep only the bools that are false:
            m = s.model()
            bools = [b for b in bools if is_false(m.evaluate(b, model_completion=True))]
    

    This prints:

    Equivalent expressions:
       eq_0_8
       eq_1_5
       eq_2_4
       eq_6_7
    Done
    

    which looks correct to me! Note that this should work correctly even if you have 3 (or more) items that are equivalent; of course you'll see the output one pair at a time. So, some post-processing might be needed to clean that up, depending on what the code consuming the result needs.
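    One possible way to do that post-processing is to merge the reported pairs into equivalence classes with a small union-find. This is only a sketch: the helper group_pairs is hypothetical, and it assumes you have already turned names such as eq_0_8 into index pairs such as (0, 8):

```python
def group_pairs(pairs):
    """Merge index pairs such as (0, 8) into sorted equivalence classes."""
    parent = {}

    def find(a):
        # Follow parent links to the class representative, halving paths on the way.
        parent.setdefault(a, a)
        while parent[a] != a:
            parent[a] = parent[parent[a]]
            a = parent[a]
        return a

    for i, j in pairs:
        ri, rj = find(i), find(j)
        if ri != rj:
            # Keep the smaller index as the representative.
            parent[max(ri, rj)] = min(ri, rj)

    classes = {}
    for a in sorted(parent):
        classes.setdefault(find(a), []).append(a)
    return sorted(classes.values())

# Pairs taken from the output above.
print(group_pairs([(0, 8), (1, 5), (2, 4), (6, 7)]))
# Three mutually equivalent items reported pair by pair, plus one more pair.
print(group_pairs([(0, 1), (0, 2), (1, 2), (3, 4)]))
```

    Indices that never appear in any pair are simply absent from the result; they are singletons.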

    Note that I only tested this for a few test values; there might be corner case gotchas. Please do a more thorough test and report if there're any bugs!