Tags: python, algorithm, numpy, backtracking

Backtracking to find n-element vectors whose elements add up to at most K


I'm interested in the following problem mainly as a way to gain intuition about the backtracking algorithm, so I am not looking for alternative solutions that don't use backtracking.

Problem: find all n-element vectors such that the sum of their elements is less than or equal to some number K. Each element in the vector is an integer drawn from a fixed set of allowed values (in my code below, range(10), i.e. the integers 0 to 9).

Example: if n = 3 and K = 10, then [9, 0, 0] and [5, 0, 5] are solutions, while [3, 1, 8] is not (its elements sum to 12).
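To make the acceptance criterion concrete, here is a tiny check of the three example vectors. The helper name is_solution is hypothetical and only restates the definition "elements sum to at most K"; it is not part of the backtracking code.

```python
def is_solution(vector, K):
    """Hypothetical helper: True when the vector's elements sum to at most K."""
    return sum(vector) <= K

K = 10
for vec in ([9, 0, 0], [5, 0, 5], [3, 1, 8]):
    print(vec, sum(vec), is_solution(vec, K))
# [9, 0, 0] 9 True
# [5, 0, 5] 10 True
# [3, 1, 8] 12 False
```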

I've adapted Python code from this site to try to implement a solution.

Here is the general "backtracking engine" function:

def solve(values, safe_up_to, size):

    solution = [None] * size

    def extend_solution(position):
        for value in values:
            solution[position] = value
            if safe_up_to(solution, position):
                if position >= size-1 or extend_solution(position+1):
                    return solution
        return None

    return extend_solution(0)
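As written, extend_solution returns as soon as it fills the last position, so solve produces at most one vector no matter how permissive the check is. One way to see this is to plug in a check that accepts everything. In the sketch below the engine is copied unchanged; always_safe is a hypothetical stand-in for the real check.

```python
def solve(values, safe_up_to, size):
    # Engine copied unchanged from above.
    solution = [None] * size

    def extend_solution(position):
        for value in values:
            solution[position] = value
            if safe_up_to(solution, position):
                # Returning here ends the whole search at the first full vector.
                if position >= size - 1 or extend_solution(position + 1):
                    return solution
        return None

    return extend_solution(0)

def always_safe(partial_solution, position):
    # Hypothetical check that accepts every partial vector.
    return True

print(solve(range(3), always_safe, 2))  # [0, 0]
```

Even though all nine 2-vectors over range(3) pass always_safe, only the first one is returned.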

And here is the function to check if the solution is "safe so far":

def safe_up_to(partial_solution, target = 100): 
   partial_solution = np.array(partial_solution)  # convert to np array 

   # replace None with NaN
   partial_solution = np.where(partial_solution == None, np.nan, partial_solution)

   if np.nansum(partial_solution) <= target: 
       return True
   else: 
       return False 
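Note that the engine calls safe_up_to(solution, position), while safe_up_to is declared as safe_up_to(partial_solution, target=100), so the current position is bound to target instead of the intended 100. The sketch below uses a pure-Python stand-in with the same signature (an assumption: it skips None entries instead of converting to a NumPy array with NaN) to show the effect.

```python
def safe_up_to(partial_solution, target=100):
    # Pure-Python stand-in: sum the filled-in entries, ignoring None placeholders.
    filled = [v for v in partial_solution if v is not None]
    return sum(filled) <= target

solution = [1, None, None]
position = 0
print(safe_up_to(solution, position))  # False: position 0 is bound to target
print(safe_up_to(solution))            # True: the default target of 100 applies
```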

However, when I run the two functions together as follows, I only get a single vector of all zeroes:

solve(values=range(10), safe_up_to=safe_up_to, size=5)

How should I modify this code to get all feasible solutions?


Solution

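  • Here is a gently modified version of your code. I tried to make it work while changing as little as possible. The main changes: extend_solution is now a generator that yields every complete vector instead of returning the first one it finds; safe_up_to is called with the partial solution only, with target bound via functools.partial; and each slot is reset to None after its loop finishes, so stale values don't leak into the sum of a shorter prefix: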

    import numpy as np
    from functools import partial
    
    def solve(values, safe_up_to, size):
    
        solution = [None] * size
    
        def extend_solution(position):
            for value in values:
                solution[position] = value
                if safe_up_to(solution):
                    if position >= size-1:
                        yield np.array(solution)
                    else:
                        yield from extend_solution(position+1)
            solution[position] = None
    
        return extend_solution(0)
    
    def safe_up_to(target, partial_solution): 
       partial_solution = np.array(partial_solution)  # convert to np array 
    
       # replace None with NaN
       partial_solution = np.where(partial_solution == None, np.nan, partial_solution)
    
       if np.nansum(partial_solution) <= target: 
           return True
       else: 
           return False 
    
    for sol in solve(values=range(10), safe_up_to=partial(safe_up_to,4), size=2):
        print(sol,sol.sum())
    

    Prints:

    [0 0] 0
    [0 1] 1
    [0 2] 2
    [0 3] 3
    [0 4] 4
    [1 0] 1
    [1 1] 2
    [1 2] 3
    [1 3] 4
    [2 0] 2
    [2 1] 3
    [2 2] 4
    [3 0] 3
    [3 1] 4
    [4 0] 4
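The generator above still tries every value at each position, even after the prefix sum has already exceeded the target. Assuming the allowed values are non-negative and iterated in ascending order (true for range(10)), the loop can stop at the first value that overshoots, because every later value would overshoot too. The sketch below also carries a running sum instead of re-summing the partial vector with np.nansum at every step. The names solve_pruned and extend are hypothetical, and this is one possible refinement rather than a drop-in replacement for the code above.

```python
def solve_pruned(values, target, size):
    """Yield every size-length tuple over `values` whose sum is <= target.

    Assumes size >= 1 and that `values` is non-negative and sorted
    ascending, so once a value pushes the running sum past `target`,
    every later value at that position would as well.
    """
    solution = [None] * size

    def extend(position, running_sum):
        for value in values:
            new_sum = running_sum + value
            if new_sum > target:
                break  # prune: the remaining values are even larger
            solution[position] = value
            if position == size - 1:
                yield tuple(solution)  # snapshot of a complete vector
            else:
                yield from extend(position + 1, new_sum)
        solution[position] = None  # undo this slot before backtracking

    return extend(0, 0)

sols = list(solve_pruned(range(10), target=4, size=2))
print(len(sols))  # 15
print(sols[:3])   # [(0, 0), (0, 1), (0, 2)]
```

It yields the same 15 vectors, in the same order, as the generator version for range(10), target 4 and size 2.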