
Is it possible to solve for function operations given inputs and outputs?


I have a function that takes two integers as input and outputs an integer.

I have a set of inputs and their known outputs.

Is it possible to figure out what operations the function is applying to the inputs to arrive at the output?

I'm not sure how to even begin modeling such a problem with z3. Any help would be greatly appreciated.

Example data:

f(1, 1) = 1
f(3, 7) = 28
f(3, 2) = 3
f(4, 6) = 56
f(10, 3) = 55
f(x, y) = f(y, x)

Solution

  • Yes, you can. But the results won't be interesting unless you at least prescribe some structure for what f looks like.

    The obvious way to do this is to declare f as an uninterpreted function and then look at the model z3 constructs for it. But this gives only a rudimentary definition of f: it will satisfy your axioms in the most uninteresting way. Here's how you'd write it:

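    from z3 import *

    # f : Int x Int -> Int with no definition, so z3 may pick any interpretation
    f = Function('f', IntSort(), IntSort(), IntSort())
    
    s = Solver()
    # The known input/output pairs
    s.add(f( 1, 1) ==  1)
    s.add(f( 3, 7) == 28)
    s.add(f( 3, 2) ==  3)
    s.add(f( 4, 6) == 56)
    s.add(f(10, 3) == 55)
    # Symmetry: f(x, y) == f(y, x) for all integers x and y
    x, y = Ints('x y')
    s.add(ForAll([x, y], f(x, y) == f(y, x)))
    
    print(s.check())  # satisfiability result
    print(s.model())  # the interpretation z3 built for f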
    

    And here's what z3 prints:

    sat
    [f = [else ->
          If(And(Not(Var(0) == 6),
                 Not(Var(0) == 4),
                 Not(Var(0) == 10),
                 Var(0) == 3,
                 Not(Var(1) == 6),
                 Not(Var(1) == 4),
                 Var(1) == 10),
             55,
             If(And(Var(0) == 6, Not(Var(1) == 6), Var(1) == 4),
                56,
                If(And(Not(Var(0) == 6),
                       Not(Var(0) == 4),
                       Not(Var(0) == 10),
                       Not(Var(0) == 3),
                       Not(Var(0) == 1),
                       Var(0) == 2,
                       Not(Var(1) == 6),
                       Not(Var(1) == 4),
                       Not(Var(1) == 10),
                       Var(1) == 3),
                   3,
                   If(And(Not(Var(0) == 6),
                          Not(Var(0) == 4),
                          Not(Var(0) == 10),
                          Not(Var(0) == 3),
                          Not(Var(0) == 1),
                          Not(Var(0) == 2),
                          Not(Var(1) == 6),
                          Not(Var(1) == 4),
                          Not(Var(1) == 10),
                          Var(1) == 3),
                      28,
                      If(And(Not(Var(0) == 6),
                             Not(Var(0) == 4),
                             Var(0) == 10,
                             Not(Var(1) == 6),
                             Not(Var(1) == 4),
                             Not(Var(1) == 10),
                             Var(1) == 3),
                         55,
                         If(And(Not(Var(0) == 6),
                                Var(0) == 4,
                                Var(1) == 6),
                            56,
                            If(And(Not(Var(0) == 6),
                                   Not(Var(0) == 4),
                                   Not(Var(0) == 10),
                                   Var(0) == 3,
                                   Not(Var(1) == 6),
                                   Not(Var(1) == 4),
                                   Not(Var(1) == 10),
                                   Not(Var(1) == 3),
                                   Not(Var(1) == 1),
                                   Var(1) == 2),
                               3,
                               If(And(Not(Var(0) == 6),
                                      Not(Var(0) == 4),
                                      Not(Var(0) == 10),
                                      Var(0) == 3,
                                      Not(Var(1) == 6),
                                      Not(Var(1) == 4),
                                      Not(Var(1) == 10),
                                      Not(Var(1) == 3),
                                      Not(Var(1) == 1),
                                      Not(Var(1) == 2)),
                                  28,
                                  If(And(Not(Var(0) == 6),
                                         Not(Var(0) == 4),
                                         Not(Var(0) == 10),
                                         Not(Var(0) == 3),
                                         Var(0) == 1,
                                         Not(Var(1) == 6),
                                         Not(Var(1) == 4),
                                         Not(Var(1) == 10),
                                         Not(Var(1) == 3),
                                         Var(1) == 1),
                                     1,
                                     12)))))))))]]
    

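    The way to read this output is to substitute x for Var(0) and y for Var(1), and treat the nested if-then-else expressions as the defining clauses of f. The final else value (12 here) is an arbitrary value z3 picked for every input not covered by those clauses.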

    I haven't checked the output line by line, but I'm sure it is correct, in the sense that it satisfies your requirements perfectly. But I can hear you say "that's not what I really wanted!" And indeed, this is not the general or minimal function you were hoping to see. Given how SMT solvers work, you'll never get a minimal answer unless you describe some skeleton for z3 to fill in.

    Note that this is an active research area: how to use (semi-)automated theorem provers, SMT solvers, etc. to "write" code for us. The general area is known as SyGuS (Syntax-Guided Synthesis). If you want to learn more, start with https://sygus.org, which gives a general description of the problem, and then read this paper: https://sygus.org/assets/pdf/FMCAD'13_SyGuS.pdf