I'm interested in the following problem mainly as a way to gain intuition about the backtracking algorithm, so I am not looking for alternative solutions that don't use backtracking.
Problem: find all n-element vectors whose elements sum to at most some number K. Each element in the vector is a non-negative integer (if negative entries were allowed, there would be infinitely many solutions).
Example: if n = 3 and K = 10, then [9, 0, 0] and [5, 0, 5] are solutions, while [3, 1, 8] is not, because 3 + 1 + 8 = 12 > 10.
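To make the condition concrete, here is a small sketch of the feasibility test and of the expected number of solutions. The helper names `is_solution` and `count_solutions` are hypothetical. The count uses the stars-and-bars identity: adding a slack entry turns "n entries summing to at most K" into "n + 1 entries summing to exactly K". This assumes every entry may range from 0 up to K.

```python
from math import comb


def is_solution(vector, k):
    """Hypothetical helper: True if all entries are non-negative and they sum to at most k."""
    return all(x >= 0 for x in vector) and sum(vector) <= k


def count_solutions(n, k):
    """Hypothetical helper: number of non-negative integer n-vectors with sum <= k (stars and bars)."""
    return comb(k + n, n)


print(is_solution([9, 0, 0], 10))  # True
print(is_solution([5, 0, 5], 10))  # True
print(is_solution([3, 1, 8], 10))  # False: 3 + 1 + 8 = 12 > 10
print(count_solutions(3, 10))      # 286
print(count_solutions(2, 4))       # 15
```

If the entries are capped below K (for example `values=range(10)` with K = 10), an enumeration will find somewhat fewer vectors than this count. For n = 2 and K = 4 the cap does not matter.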
I adapted Python code from this site to try to implement a solution.
Here is the general "backtracking engine" function:
```python
def solve(values, safe_up_to, size):
    solution = [None] * size

    def extend_solution(position):
        for value in values:
            solution[position] = value
            if safe_up_to(solution, position):
                if position >= size-1 or extend_solution(position+1):
                    return solution
        return None

    return extend_solution(0)
```
And here is the function that checks whether the partial solution is "safe so far":
```python
import numpy as np


def safe_up_to(partial_solution, target = 100):
    partial_solution = np.array(partial_solution) # convert to np array
    # replace None with NaN
    partial_solution = np.where(partial_solution == None, np.nan, partial_solution)
    if np.nansum(partial_solution) <= target:
        return True
    else:
        return False
```
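Converting `None` to NaN is one way to ignore the unfilled slots. A numpy-free sketch of the same check can simply skip `None` while summing. The name `safe_up_to_plain` is hypothetical and not part of the code above.

```python
def safe_up_to_plain(partial_solution, target=100):
    """Hypothetical numpy-free variant: sum only the filled slots (those that are not None)."""
    filled = (v for v in partial_solution if v is not None)
    return sum(filled) <= target


print(safe_up_to_plain([3, None, None], target=10))  # True
print(safe_up_to_plain([3, 1, 8], target=10))        # False
```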
However, when I run these two functions together, I only get a single vector of all zeroes:

```python
solve(values=range(10), safe_up_to=safe_up_to, size=5)
```
How should I modify this code to get all feasible solutions?
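Here is a gently modified version of your code; I tried to make it work while changing as little as possible. Your version stops at the first complete vector it finds, which is all zeros because `values` starts at 0. It also passes `position` where `safe_up_to` expects `target`. In the version below, `extend_solution` is a generator that yields every complete vector. Each slot is reset to `None` once it has been explored, so stale values do not affect the feasibility check of earlier positions. Finally, `target` is bound with `functools.partial`: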
```python
import numpy as np
from functools import partial


def solve(values, safe_up_to, size):
    solution = [None] * size

    def extend_solution(position):
        for value in values:
            solution[position] = value
            if safe_up_to(solution):
                if position >= size-1:
                    # complete vector: yield a copy so later changes do not alter it
                    yield np.array(solution)
                else:
                    # partial vector is still feasible: go on to the next position
                    yield from extend_solution(position+1)
        # backtrack: clear this slot before returning to the previous position
        solution[position] = None

    return extend_solution(0)


def safe_up_to(target, partial_solution):
    partial_solution = np.array(partial_solution) # convert to np array
    # replace None with NaN
    partial_solution = np.where(partial_solution == None, np.nan, partial_solution)
    if np.nansum(partial_solution) <= target:
        return True
    else:
        return False


# partial(safe_up_to, 4) fixes target=4, so solve only passes the partial solution
for sol in solve(values=range(10), safe_up_to=partial(safe_up_to,4), size=2):
    print(sol,sol.sum())
```
Prints:
```
[0 0] 0
[0 1] 1
[0 2] 2
[0 3] 3
[0 4] 4
[1 0] 1
[1 1] 2
[1 2] 3
[1 3] 4
[2 0] 2
[2 1] 3
[2 2] 4
[3 0] 3
[3 1] 4
[4 0] 4
```
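Because `values` is sorted in ascending order and contains no negative numbers, the loop can stop as soon as one value pushes the running sum over the target, since every later value would too. The sketch below applies that pruning and tracks the running sum instead of re-summing the partial vector at every step. The name `solve_pruned` is hypothetical. It yields tuples rather than numpy arrays, and it relies on the ascending, non-negative assumption stated in its docstring.

```python
def solve_pruned(values, target, size):
    """Hypothetical sketch: yield every length-`size` vector from `values` with sum <= target.

    Assumes `values` is sorted ascending and non-negative, so once a value makes the
    running sum exceed `target`, every larger value (and any completion) will too.
    """
    values = list(values)
    solution = []

    def extend(total):
        if len(solution) == size:
            yield tuple(solution)
            return
        for value in values:
            if total + value > target:
                break  # prune: larger values cannot fit either
            solution.append(value)            # choose
            yield from extend(total + value)  # explore
            solution.pop()                    # un-choose (backtrack)

    return extend(0)


solutions = list(solve_pruned(range(10), target=4, size=2))
print(len(solutions))  # 15
print(solutions[:3])   # [(0, 0), (0, 1), (0, 2)]
```

Appending to and popping from the `solution` list plays the same role as writing a value into a slot and resetting it to `None` in the version above.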