python, recursion, generator

Generating lists satisfying a linear integer constraint in Python


In Python, I have a list w of integers, and I would like to efficiently generate all lists f of non-negative integers whose weighted sum, using the weights in w, equals a given value a, that is

sum([i*j for i,j in zip(w,f)]) == a

i.e. the dot product of w and f must equal a. Everything here is integer-valued and non-negative. The brute-force version would use itertools:

from itertools import product

a = 20
w = [1,1,3,3]
rslt = []
for f in product(range(a+1), repeat=len(w)):
    if sum(i*j for i, j in zip(w, f)) == a:
        rslt.append(list(f))

print(rslt)

This, however, scales very badly with a. Is there a way to generate rslt efficiently? I couldn't find a way to impose the constraint with itertools directly. Mathematically, given a known vector w, I want to find all non-negative integer vectors f whose dot product with w equals a.
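For scale: the task is to enumerate the non-negative integer solutions of a linear equation in n = len(w) unknowns, and the brute force above tests every point of a cube of side a+1, regardless of how few of those points are solutions:

$$\sum_{i=1}^{n} w_i f_i = a, \qquad f_i \in \{0, 1, 2, \dots\}$$

$$\#\text{candidates} = (a+1)^{n}, \qquad a = 60,\ n = 4:\quad 61^{4} = 13\,845\,841$$

Each candidate costs a full dot product, so the running time grows like a^n even when the number of actual solutions is small.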

I came up with the idea of a recursive function that computes the remaining difference from a at each stage of the loop, but I'm having trouble implementing it. If an implementation could also handle arbitrary constraints (e.g. quadratic), that would be fantastic.


Solution

  • Since this is tagged "recursion", I guess this solution is acceptable:

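    def bb(a, w):
        # If the target is 0, all remaining numbers must be 0
        if a == 0: return [(0,) * len(w)]
        # If the target is non-zero, we need at least one remaining weight;
        # otherwise there are no solutions on this branch
        if len(w) == 0: return []
        # A zero weight would make range() fail (and allows infinitely many
        # solutions); a negative weight would silently produce no solutions
        if w[0] <= 0: raise ValueError("all weights must be positive")
        rslt = []
        # Try every possible contribution of the first weight (a multiple of w[0]),
        # combined (recursively) with every solution for the remaining weights
        for first in range(0, a + 1, w[0]):
            for rest in bb(a - first, w[1:]):
                rslt.append((first // w[0],) + rest)
        return rslt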
    

    On my machine, with a=60, this takes 0.03 seconds vs. 10 seconds for yours. (Plus, it finds solutions that yours is missing, such as the (0,20,0,0) mentioned in the comments. I assume non-negative numbers.)
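A quick way to check the recursive version is to run it next to the question's brute force on the question's own small example and compare the results. This is a sketch: brute_force is a hypothetical wrapper around the question's loop, and bb is the list-returning recursion from this answer, copied so the snippet runs on its own. With the brute force as written in the question (which loops over range(a+1)), the two should return the same 252 tuples for a=20, w=[1,1,3,3].

```python
from itertools import product


def brute_force(a, w):
    """Hypothetical wrapper around the question's loop: test every f in range(a+1)**len(w)."""
    return [f for f in product(range(a + 1), repeat=len(w))
            if sum(i * j for i, j in zip(w, f)) == a]


def bb(a, w):
    """Recursive enumeration from the answer (assumes all weights are positive)."""
    if a == 0:
        return [(0,) * len(w)]   # target reached: remaining numbers are all 0
    if not w:
        return []                # non-zero target, no weights left: no solution
    rslt = []
    # Each multiple of w[0] up to a is a possible contribution of the first weight
    for first in range(0, a + 1, w[0]):
        for rest in bb(a - first, w[1:]):
            rslt.append((first // w[0],) + rest)
    return rslt


a, w = 20, [1, 1, 3, 3]
fast = bb(a, w)
slow = brute_force(a, w)
# Compare as sorted lists so the check does not depend on generation order
print(len(fast), sorted(fast) == sorted(slow), (0, 20, 0, 0) in fast)
```

Checking a small case like this is cheap because the brute force is still tractable at a=20; for larger a only the recursive version is practical.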

    Edit: generator

    I see that this is also tagged generator, so I guess this is closer to what is expected:

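    def bb(a, w):
        # Target reached: the remaining numbers are all 0
        if a == 0:
            yield (0,) * len(w)
            return
        # Non-zero target but no weights left: no solution on this branch
        if len(w) == 0: return
        # A zero weight would allow infinitely many solutions; reject non-positive weights
        if w[0] <= 0: raise ValueError("all weights must be positive")
        # Lazily yield each choice for the first number, followed by
        # every solution for the remaining weights
        for first in range(0, a + 1, w[0]):
            for rest in bb(a - first, w[1:]):
                yield (first // w[0],) + rest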
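The question also asks about non-linear constraints. The same recursion extends to any separable constraint of the form sum(g_i(f_i)) == a, assuming each g_i satisfies g_i(0) == 0 and is strictly increasing on the non-negative integers, so the loop over each f_i can stop as soon as g_i(f_i) exceeds the remaining target. The following is a minimal sketch under that assumption; solve_separable and the example cost functions are hypothetical, not part of the answer above. Constraints with cross terms (e.g. f0*f1) are not separable and would need a different bound.

```python
from itertools import islice
from typing import Callable, Iterator, Sequence, Tuple


def solve_separable(a: int,
                    costs: Sequence[Callable[[int], int]]) -> Iterator[Tuple[int, ...]]:
    """Yield every tuple f of non-negative ints with sum(costs[i](f[i])) == a.

    Assumes each cost function c has c(0) == 0 and is strictly increasing
    on non-negative integers, which guarantees the while loop terminates.
    """
    if not costs:
        # No unknowns left: the empty tuple is a solution only if nothing remains
        if a == 0:
            yield ()
        return
    cost, rest_costs = costs[0], costs[1:]
    k = 0
    c = cost(0)
    # Try f[0] = 0, 1, 2, ... while its cost still fits in the remaining target
    while c <= a:
        for rest in solve_separable(a - c, rest_costs):
            yield (k,) + rest
        k += 1
        c = cost(k)


# Linear case from the question; wi=wi binds each weight at lambda creation time
w = [1, 1, 3, 3]
linear = list(solve_separable(20, [lambda k, wi=wi: wi * k for wi in w]))

# Hypothetical quadratic example: f0**2 + 2*f1**2 == 9
quadratic = list(solve_separable(9, [lambda k: k * k, lambda k: 2 * k * k]))

# Because it is a generator, callers can stop early without building the full list
first_three = list(islice(solve_separable(20, [lambda k, wi=wi: wi * k for wi in w]), 3))
print(len(linear), quadratic, first_three)
```

Passing cost functions instead of a weight list keeps the structure identical to bb; the linear case is recovered with g_i(k) = w[i]*k, and the generator form means islice (or a plain for loop with break) only computes as many solutions as are actually consumed.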