
Z3: optimize by index, not by value


With great respect to @alias's answer there (Find minimum sum), I would like to solve a similar puzzle. I have 4 agents and 4 types of work. Each agent does each job for some price (see the initial matrix in the code). I need to find the optimal allocation of agents to jobs. The following code is almost a copy-paste of the mentioned answer:

import itertools
from z3 import *

initial = (  # Row - agent, Column - work
    (7, 7, 3, 6),
    (4, 9, 5, 4),
    (5, 5, 4, 5),
    (6, 4, 7, 2)
)

opt = Optimize()

# a_i holds the price chosen for agent i (a value, not a column index)
agent = [Int(f"a_{i}") for i, _ in enumerate(initial)]
# all chosen prices must be pairwise different
opt.add(And(*(a != b for a, b in itertools.combinations(agent, 2))))

# each agent's price must be one of the entries in its row
for w, row in zip(agent, initial):
    opt.add(Or(*[w == val for val in row]))

minTotal = Int("minTotal")
opt.add(minTotal == sum(agent))
opt.minimize(minTotal)
print(opt.check())
print(opt.model())

The mathematically correct answer, [a_2 = 4, a_1 = 5, a_3 = 2, a_0 = 3, minTotal = 14], doesn't work for me, because I need the index (which work each agent gets) instead of the value. So my question is: how can I rework the code to optimize over indices instead of values? I've tried to use an Array, but I have no idea how to minimize multiple sums.


Solution

  • You can simply keep track of the indices and walk each row to pick the corresponding element. Note that the itertools.combinations construction can be replaced by Distinct. We also add extra constraints to make sure the indices are between 1 and 4, so that there's no out-of-bounds access:

    from z3 import *
    
    initial = (  # Row - agent, Column - work
        (7, 7, 3, 6),
        (4, 9, 5, 4),
        (5, 5, 4, 5),
        (6, 4, 7, 2)
    )
    
    opt = Optimize()
    
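    # Return the element of vs at the 1-based symbolic index i, built as a
    # chain of If's; indices outside 1..len(vs) fall through to 0.
    def choose(i, vs):
        if vs:
            return If(i == 1, vs[0], choose(i-1, vs[1:]))
        else:
            return 0
    
    # a_i is the 1-based column, i.e. the work assigned to agent i
    agent = [Int(f"a_{i}") for i, _ in enumerate(initial)]
    opt.add(Distinct(*agent))  # no two agents get the same work
    for a, row in zip(agent, initial):
        opt.add(a >= 1)
        opt.add(a <= 4)
        # implied by the bounds above, but harmless
        opt.add(Or(*[choose(a, row) == val for val in row]))
    
    minTotal = Int("minTotal")
    opt.add(minTotal == sum(choose(a, row) for a, row in zip(agent, initial)))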
    opt.minimize(minTotal)
    print(opt.check())
    print(opt.model())
    

    This prints:

    sat
    [a_1 = 1, a_0 = 3, a_2 = 2, a_3 = 4, minTotal = 14]
    

    which I believe is what you're looking for.
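    As a sanity check (not part of the z3 formulation), a plain-Python brute force over all 4! = 24 one-to-one assignments confirms that 14 is the minimum and finds the same 1-based columns (3, 1, 2, 4) for agents 0..3. The helper name brute_force is hypothetical; this only scales to small matrices, since it enumerates every permutation.

```python
from itertools import permutations

# Cost matrix from the question: row = agent, column = work.
initial = (
    (7, 7, 3, 6),
    (4, 9, 5, 4),
    (5, 5, 4, 5),
    (6, 4, 7, 2),
)


def brute_force(costs):
    """Try every one-to-one assignment of agents to works; return the cheapest.

    Returns (total, cols) where cols[i] is the 1-based column given to agent i.
    """
    n = len(costs)

    def cost(perm):
        # perm[agent] is the 0-based column assigned to that agent
        return sum(costs[agent][col] for agent, col in enumerate(perm))

    best = min(permutations(range(n)), key=cost)
    return cost(best), tuple(col + 1 for col in best)


print(brute_force(initial))  # (14, (3, 1, 2, 4))
```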

    Note that z3 also supports arrays, which you could use for this problem. However, in SMT-LIB, arrays are not "bounded" as they are in programming languages: they're indexed by every element of their domain type. So you wouldn't gain much from using them, and the formulation above seems to be the most straightforward.