
Z3 Solver::check() gets stuck on solvable input


On day 17 of Advent of Code 2024 (AOC2024), there was a problem involving a VM where you are supposed to brute-force the input that yields a certain output, after recognizing that the input is consumed in 3-bit chunks (hence the //= 8). While I eventually got the solution by brute force, I would like to try a different approach, since I saw someone use a SAT solver. Everybody's program is different, of course, so I can't just look up how to apply a SAT solver to mine, but I wanted to try to learn how to use Z3 anyway.
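The 3-bit-chunk observation in a few lines: a minimal sketch (`octal_chunks` is a hypothetical helper, not part of the puzzle) showing that each `a //= 8` strips one octal digit off `a`, so the VM loop runs once per octal digit and emits exactly one value per digit.

```python
def octal_chunks(a):
    """Hypothetical helper: split a non-negative int into 3-bit chunks, lowest first."""
    chunks = []
    while a > 0:
        chunks.append(a % 8)  # lowest 3 bits, the same as a & 0b111
        a //= 8               # drop those 3 bits, the same as a >> 3
    return chunks


print(oct(123))           # 0o173
print(octal_chunks(123))  # [3, 7, 1]
```

So an output of length n can only come from an `a` with exactly n octal digits, i.e. `8**(n-1) <= a < 8**n`.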

I should note that I have never used Z3 before, so if I'm doing something very stupid, I apologize, but I would like to learn the correct way of doing it.

from z3 import IntVal, Int2BV, BV2Int, Solver, Int


def program(a):
    b = 0
    c = 0
    output = []

    while a > 0:
        b = a % 8
        b = b ^ 6
        denominator = 2**b
        c = a // denominator
        b = b ^ c
        b = b ^ 4
        output.append(b % 8)
        a //= 8

    return output


print(program(123))  # [ 2, 2, 3 ]
output = [2, 2, 3]

a = Int('a')

s = Solver()
s.add(a > 0)

# constants
eight = IntVal(8)

six = IntVal(6)
six_bv = Int2BV(six, 32)

four = IntVal(4)
four_bv = Int2BV(four, 32)

a_temp = a
for x in output:
    # b = a % 8
    s1 = a_temp % eight
    s1_bv = Int2BV(s1, 32)

    # b = b ^ 6
    s2_bv = s1_bv ^ six_bv
    s2 = BV2Int(s2_bv)

    # c = a / (2 ** b)
    s3_denom = IntVal(2) ** s2
    s3 = a_temp / s3_denom
    s3_bv = Int2BV(s3, 32)

    # b = b ^ c
    s4_bv = s2_bv ^ s3_bv

    # b = b ^ 4
    s5_bv = s4_bv ^ four_bv
    s5 = BV2Int(s5_bv)

    # out(b % 8)
    character = s5 % eight

    print(character, "==", x)
    s.add(character == x)
    a_temp = a_temp / 8

s.add(a_temp == 0)

print(s.check())
print(s.model())

The problem was originally expressed in a custom ISA, so I decompiled it and rewrote it in Python; that is the program function defined at the top. Putting in 123 yields [2, 2, 3]. I then tried to give Z3 the same algorithm, constraining each computed value to equal the corresponding element of the output.

However, Z3 just hangs when calling s.check(). When I provide a smaller input, I sometimes get unknown instead, but I can't get it to return sat and give me the number I used to generate the output.

I already tried running with more or fewer constraints (for example, I later added a > 0, and also required a_temp to end up == 0 so that the loop terminates).

I also wondered whether Z3 handles integer division properly, but my testing seems to show that Z3 does perform integer division on Int terms by default, so that part should already be correct.


Solution

  • Mixing and matching integers and bit-vectors typically causes performance problems for SMT solvers, because conversions between the two create non-linear constraints.

    Non-linear integer arithmetic is undecidable, and the algorithms used for such problems are not particularly efficient. So you might get unknown, or the solver might loop forever, unless some heuristic happens to find a solution. Mixing and matching different kinds of numbers in arithmetic should therefore be avoided if you can help it.

    I haven't studied the original AOC problem, but I assume it allows the a/b/c values to be arbitrary integers, so modeling them with z3's Int is reasonable. But you keep converting them to and from 32-bit bit-vectors, so in a sense they are not treated as unbounded values anyway. So I'd make that explicit: forget about unbounded integers and simply limit a/b/c to 32-bit machine integers. (Whether this is acceptable is, of course, problem dependent.)

    Based on this, I'd code your problem as follows:

    from z3 import *
    
    output = [2, 2, 3]
    
    s = Solver()
    
    a, b, c = BitVecs('a b c', 32)
    s.add(a > 0)
    
    for x in output:
       b = a % 8
       b = b ^ 6
       denominator = 1 << b
       c = a / denominator
       b = b ^ c
       b = b ^ 4
       s.add (x == b % 8)
       a /= 8
    
    s.add(a <= 0)
    
    print(s.check())
    print(s.model())
    

    When I run this, I get:

    sat
    [a = 115]
    

    So z3 suggests that a should be 115. From your question, I gather you expected 123. It turns out that different starting states can produce the same output. Indeed:

    def program(a):
        b = 0
        c = 0
        output = []
    
        while a > 0:
            b = a % 8
            b = b ^ 6
            denominator = 2**b
            c = a // denominator
            b = b ^ c
            b = b ^ 4
            output.append(b % 8)
            a //= 8
    
        return output
    
    print(program(123))
    print(program(115))
    

    prints:

    [2, 2, 3]
    [2, 2, 3]
    

    So it looks like z3 successfully reverse-engineered your problem; it just found a different starting state that happens to produce the same output.

    Long story short, SMT solvers are indeed good for these kinds of reverse-engineering problems, but there are limitations. Try to stick to bit-vectors as much as possible, and in particular do not mix-and-match different kinds of numbers if you can help it.