Search code examples
pythonalgorithmoptimizationcoin-change

Similar to a coin change problem, but with "coin" repetitions and another optimization goal


The goal is to get a list of all possible variants of sequences of positive integer (nominal coin) values from a list L where each integer (nominal coin) value can be used multiple times (i.e. allowing repetitions) where the sum of the (nominal coin) values equals targetSum with the constraint that the total amount of numbers (coins) in the generated variant of the sequence is limited to the range between n and m (including n and m).

The code below is what I have came up with up to now, but it runs way too slow for the target purpose of being part of an optimization problem:

def allArrangementsOfIntegerItemsInLsummingUpTOtargetSum(L, targetSum, n=None, m=None):
    if n is None:   n  = 1
    if m is None:   m = targetSum
    lenL = len(L)
    # print(f"{targetSum=}, {L=}, {n=}, {m=}")
    Stack           = []
    # Initialize the Stack with the starting point for each element in L
    for i in range(lenL):
        currentSum  =   L[ i ]
        path        = [   L[ i ]   ]
        start       = 0         # Start from 0 allows revisiting of all items
        Stack.append(   (currentSum, path, start )   )  

    while Stack:
        currentSum, path, start = Stack.pop()
        # Check if the current path meets the criteria
        if currentSum == targetSum and n <= len(path) <= m:
            yield path
        if currentSum > targetSum or len(path) > m:
            continue  
        # ^ - NEXT please: stop exploring this path as it's not valid or complete

        # Continue to build the path by adding elements from L, starting from 0 index
        for i in range(len(L)):  # Change start to 0 if you want to include all permutations
            newSum = currentSum + L[ i  ]
            newPath = path + [ L[ i  ]  ]
            Stack.append((newSum, newPath, 0))  # Start from 0 allows every possibility
# def allArrangementsOfIntegerItemsInLsummingUpTOtargetSum
splitsGenerator = allArrangementsOfIntegerItemsInLsummingUpTOtargetSum

L = [ 13, 17, 23, 24, 25 ] ; targetSum = 30 ; m=1 ; n=30
sT=T(); listOfSplits = list(splitsGenerator(L, targetSum) ); eT=T(); dT=eT-sT
print( f"{dT=:.6f} s  ->  L = [ 13, 17, 23, 24, 25 ] ; targetSum =   30  ->  {listOfSplits}")

L = [ 60, 61, 62, 63, 64 ] ; targetSum = 600 # m=1 ; n=6000 are DEFAULT values
sT=T(); listOfSplits = list(splitsGenerator(L, targetSum) ); eT=T(); dT=eT-sT
print( f"{dT=:.6f} s  ->  L = [ 60, 61, 62, 63, 64 ] ; targetSum = 6000  ->  {listOfSplits}")

giving as output:

dT=0.000047 s  ->  L = [ 13, 17, 23, 24, 25 ] ; targetSum =  30  ->  [[17, 13], [13, 17]]
dT=5.487905 s  ->  L = [ 60, 61, 62, 63, 64 ] ; targetSum = 600  ->  [[60, 60, 60, 60, 60, 60, 60, 60, 60, 60]]

where you can see that the algorithm needs over 5 seconds to come up with the only possible variant making same kind of calculation for targetSum = 6000 lasting "forever".

Any idea of how to write code able to come up at least an order of magnitude faster with a result?

I searched the Internet already for weeks and found that all of the knapsack, coin change and dynamic programming based known optimization approaches are not covering such a basic task which special case is to be used in order to divide a list of items into sub-lists (partitions) with sizes ranging from a to b for the purpose of optimization of an overall weight function which uses values obtained from a local weight function calculating single weights out of the items in each of the sub-list (partition).

Notice that both sequences [[17, 13], [13, 17]] consist of same (nominal coin) values. In other words the order of the values matter giving two variants instead of only one if the order were deliberate.

What I am interested in is another algorithm able to come up with the result much faster, so the programming language in which this algorithm is written or expressed is secondary, but I would prefer Python to describe the algorithm as this will make it as a side-effect easy to test it against the already provided code.


considering code provided in the answer by M.S:

def count_combinations(L, targetSum, n=None, m=None):
    if n is None:   n  = 1
    if m is None:   m = targetSum
    # Create the DP table with dimensions (m+1) x (targetSum+1)
    dp = [[0] * (targetSum + 1) for _ in range(m + 1)]
    dp[0][0] = 1  # Base case
    
    # Update the DP table
    for num in L:
        for count in range(m, n - 1, -1):  # Go from m to n
            for sum_val in range(num, targetSum + 1):
                dp[count][sum_val] += dp[count - 1][sum_val - num]
    
    # Extract all valid combinations
    result = []
    for count in range(n, m + 1):
        if dp[count][targetSum] > 0:
            result.extend(extract_combinations(L, count, targetSum, dp))
    
    return result

def extract_combinations(L, count, targetSum, dp):
    combinations = []

    def backtrack(current_count, current_sum, combination):
        if current_count == 0 and current_sum == 0:
            combinations.append(combination.copy())
            return
        if current_count <= 0 or current_sum <= 0:
            return
        # Backtrack from the last number considered
        for num in L:
            if current_sum >= num and dp[current_count-1][current_sum-num] > 0:
                combination.append(num)
                backtrack(current_count-1, current_sum-num, combination)
                combination.pop()
    
    backtrack(count, targetSum, [])
    return combinations

splitsGenerator = count_combinations

L = [ 13, 17, 23, 24, 25 ] ; targetSum = 30 ; m=1 ; n=30
sT=T(); listOfSplits = list(splitsGenerator(L, targetSum) ); eT=T(); dT=eT-sT
print( f"{dT=:.6f} s  ->  L = [ 13, 17, 23, 24, 25 ] ; targetSum =   30  ->  {listOfSplits}")

L = [ 60, 61, 62, 63, 64 ] ; targetSum = 600 # m=1 ; n=6000 are DEFAULT values
sT=T(); listOfSplits = list(splitsGenerator(L, targetSum) ); eT=T(); dT=eT-sT
print( f"{dT=:.6f} s  ->  L = [ 60, 61, 62, 63, 64 ] ; targetSum = 6000  ->  {listOfSplits}")

which outputs:

dT=0.000469 s  ->  L = [ 13, 17, 23, 24, 25 ] ; targetSum =   30  ->  [[13, 17], [17, 13]]
dT=0.285951 s  ->  L = [ 60, 61, 62, 63, 64 ] ; targetSum = 6000  ->  []

failing to provide a result in case of the second example.


considering the by brevity excelling code provided in an answer by chrslg:

def bb3( L, target, partial=[]):
    if target < 0  : return []
    if target ==0 : return [partial]
    if target<L[0] : return []
    sols=[]
    for c in L:
        if target>=c:
            sols.extend( bb3(L, target-c, partial+[c]) )
    return sols

splitsGenerator = bb3

L = [ 13, 17, 23, 24, 25 ] ; targetSum = 30 ; m=1 ; n=30 # m=1 ; n=30 are DEFAULT values
sT=T(); listOfSplits = list(splitsGenerator(L, targetSum) ); eT=T(); dT=eT-sT
print( f"{dT=:.6f} s  ->  L = [ 13, 17, 23, 24, 25 ] ; targetSum =   30  ->  {listOfSplits}")

L = [ 60, 61, 62, 63, 64 ] ; targetSum = 600 # m=1 ; n=6000 are DEFAULT values
sT=T(); listOfSplits = list(splitsGenerator(L, targetSum) ); eT=T(); dT=eT-sT
print( f"{dT=:.6f} s  ->  L = [ 60, 61, 62, 63, 64 ] ; targetSum = 6000  ->  {listOfSplits}")

outputs:

dT=0.000012 s  ->  L = [ 13, 17, 23, 24, 25 ] ; targetSum =   30  ->  [[13, 17], [17, 13]]
dT=0.933661 s  ->  L = [ 60, 61, 62, 63, 64 ] ; targetSum = 6000  ->  [[60, 60, 60, 60, 60, 60, 60, 60, 60, 60]]

running 4 to 6 times faster than code I have provided as reference.

Considering that it needs almost one second to come up with an obvious result in second case it is still beyond the expectation of at least an order of magnitude.

I have started with the recursive approach but because recursion can run into recursion limits the code I provided in the question is a re-written version of the recursive one and maybe therefore got slower because of replacing recursion with Stack?


considering the by speed excelling code provided in an answer by btilly:

dT=0.000140 s  ->  L = [13, 17, 23, 24, 25] ; targetSum =   30 ->  [[13, 17], [17, 13]] 
dT=0.003229 s  ->  L = [60, 61, 62, 63, 64] ; targetSum =  600 ->  [[60, 60, 60, 60, 60, 60, 60, 60, 60, 60]] 

now arrived at an actual some orders of magnitude speedup for edge cases I assumed it should be possible and therefore asked this question.


Let's provide final comparison:

Code in my question:

dT=0.000047 s  ->  L = [ 13, 17, 23, 24, 25 ] ; targetSum =  30  ->  [[17, 13], [13, 17]]
dT=5.487905 s  ->  L = [ 60, 61, 62, 63, 64 ] ; targetSum = 600  ->  [[60, 60, 60, 60, 60, 60, 60, 60, 60, 60]]

Code by chrslg :

dT=0.000012 s  ->  L = [ 13, 17, 23, 24, 25 ] ; targetSum =   30  ->  [[13, 17], [17, 13]]
dT=0.933661 s  ->  L = [ 60, 61, 62, 63, 64 ] ; targetSum = 6000  ->  [[60, 60, 60, 60, 60, 60, 60, 60, 60, 60]]

Code by btilly :

dT=0.000140 s  ->  L = [13, 17, 23, 24, 25] ; targetSum =   30 ->  [[13, 17], [17, 13]] 
dT=0.003229 s  ->  L = [60, 61, 62, 63, 64] ; targetSum =  600 ->  [[60, 60, 60, 60, 60, 60, 60, 60, 60, 60]] 

As can be seen from above there is no overall better solution as chrslg code is faster at small examples and btilly code on large ones, so the optimal solution will be using both depending on the input parameter, but ... I was actually after being able to get the larger example done much faster and the btilly code is 1700 times faster what is really impressive, isn't it?


Solution

  • Here is the much faster answer with dynamic programming.

    The main problem for dynamic programming is that there are too many states to keep track of if we keep track of taking things in any possible order. So what we have to do is decide how many of which values to take, and also keep track of in how many orders we might have done that.

    Then when we have the answer, we have to figure out both what values, and then which order to put them in.

    The full reasoning for the bookkeeping is long and complex. But I'll do three examples showing that I produce correct answers. And that I can even give the solutions both in order, or any random one by index!

    The number of operations to create the structure is O(len(coins) * max_len^2 * target). But there is a complication. Some of those operations are arithmetic calculation. If there are a large number of solutions, that can force you to big integer arithmetic. Which can be another O(m) factor. :-(

    The time to list the solutions is O(m) per solution. There is a permutation.copy() line to be sure that where it is returned to will not change permutation. If you can guarantee that this will not happen and remove that line, with most inputs the time to list solutions is O(1) per solution.

    from sortedcontainers import SortedDict
    
    class TargetNode:
    
        _repetition_factor = []
    
        # This node stores a recursive data structure describing the values in a
        # solution, and how many ways it can be put into an ordered solution.
        def __init__(self, term=None, this_len=0, prev=None, prev_solutions=None):
            # term appears
            self.term = term
            # this_len number of times
            self.this_len = this_len
            # for each thing we had before.
            self.prev = prev
            if prev is None:
                # Nothing appeared before - I'm it.
                self.prev_len = 0
                self.total_len = this_len
                # Exactly one solution before - nothing.
                self.prev_count = 1
                # For that one solution, I only have one ordered solution.
                self.repetition_factor = 1
                if self.term is None:
                    self.total_sum = 0
                else:
                    self.total_sum = term * this_len
                    self.prev_len = 0
            else:
                # Before me were prev_len things
                self.prev_len = prev.total_len
                # And I have a total length
                self.total_len = this_len + self.prev_len
                # This solution will add up to
                self.total_sum = term * this_len + prev.total_sum
                # Here's how many patterns we get from prev.
                self.prev_count = prev.total_count
                # Each of which, after inserting term this_len times turns
                # into repetition_factor ordered solutions.
                repetition_factor = self.find_repetition_factor(self.prev_len, self.this_len)
                self.repetition_factor = repetition_factor
    
            # Giving a total count.
            self.total_count = self.prev_count * self.repetition_factor
            # Oh, but we may have other solutions that didn't have this_len of term.
            self.prev_solutions = []
            if prev_solutions is not None:
                self.prev_solutions.extend(prev_solutions)
                self.total_count += sum((soln.total_count for soln in self.prev_solutions))
    
        # This is a convoluted way of choosing and memozing:
        #
        #   prev_len + this_len choose this_len
        #
        # which is the number of ways to make a new list by
        # inserting this_len elements into a list which already
        # had prev_len elements.
        #
        # I do it this way for two reasons:
        #
        # 1. Avoids repeat math whenever possible.
        # 2. Keeps numbers smaller than the usual factorial formula.
        def find_repetition_factor(self, prev_len, this_len):
            if this_len < prev_len:
                # Swap.
                (prev_len, this_len) = (this_len, prev_len)
            while len(self._repetition_factor) <= prev_len:
                self._repetition_factor.append([1])
            this_row = self._repetition_factor[prev_len]
            while len(this_row) <= this_len:
                this_row.append((this_row[-1] * (prev_len + len(this_row))) // len(this_row))
            return this_row[this_len]
    
        # Let's track another solution in this one.
        def add_prev_solution (self, other):
            if other is not None:
                if self.total_len != other.total_len:
                    self.total_len = None # Don't try to use this again!
                self.prev_solutions.append(other)
                self.total_count += other.total_count
            return self
    
        # Iterator through all of the value/frequency combinations.
        def value_count_iter (self):
            if 0 == self.total_len:
                yield []
            else:
                # Using a hand-rolled stack to avoid Python's recursion limit.
                value_count = SortedDict()
                node_stack = [self]
                index_stack = [-1]
                while 0 < len(node_stack):
                    node = node_stack[-1]
                    index = index_stack[-1]
                    if -1 == index:
                        if 0 == node.total_len:
                            yield value_count
                            node_stack.pop()
                            index_stack.pop()
                        else:
                            value_count[node.term] = node.this_len
                            node_stack.append(node.prev)
                            index_stack[-1] = 0
                            index_stack.append(-1)
                    else:
                        if index == 0:
                            value_count.popitem()
                        if index < len(node.prev_solutions):
                            node_stack.append(node.prev_solutions[index])
                            index_stack[-1] += 1
                            index_stack.append(-1)
                        else:
                            node_stack.pop()
                            index_stack.pop()
    
        # Iterator through all of my solutions.
        # Basically go through value_count combinations, then put them
        # together in every possible order.
        def __iter__ (self):
            for value_count in self.value_count_iter():
                # We have count by value. Find permutations. Use an
                # explicit stack to avoid recursion limits.
                permutation = []
                # The index stack always contains the position of the last
                # value looked at for the next permutation. So this means,
                # "We've looked at none so far."
                index_stack = [-1]
    
                while 0 < len(index_stack):
                    # Get the current recursion level.
                    index = index_stack.pop()
                    if 0 == len(value_count):
                        # This copy is the most expensive thing we do!
                        # This is O(n). If you get rid of it, on a random
                        # input we average O(1).
                        yield permutation.copy()
                        # Finish with the previous recursion level as well.
                        old_value = permutation.pop()
                        value_count[old_value] = 1
                        index_stack.pop()
                    else:
                        # Do we need to replace the value into value_count?
                        if 0 <= index:
                            old_value = permutation.pop()
                            value_count[old_value] = value_count.get(old_value, 0) + 1
    
                        # Advance.
                        index += 1
    
                        # Do we have another value to go to?
                        if index < len(value_count):
                            value, count = value_count.peekitem(index)
                            permutation.append(value)
                            if count <= 1:
                                value_count.pop(value)
                            else:
                                value_count[value] = count - 1
                            # Replace the current frame so we come back.
                            index_stack.append(index)
                            # And start the new frame.
                            index_stack.append(-1)
    
        # Let's calculate a random solution. This lets us sample
        # very large sets.
        def find_ordered_solution (self, index):
            if index < self.total_count:
                # Use explicit stacks to avoid recursion limits.
                repetition_stack = []
                extra_index_stack = []
    
                # And we'll modify these as we go.
                node = self
                value_count = SortedDict()
                while 0 != node.total_len:
                    if node.prev_count * node.repetition_factor <= index:
                        index -= node.prev_count * node.repetition_factor
                        for next_node in node.prev_solutions:
                            if next_node.total_count <= index:
                                index -= next_node.total_count
                            else:
                                node = next_node
                                break # To the accounting for node.
                    else:
                        # Now we're somewhere in node's solutions.
                        value_count[node.term] = node.this_len
                        # How many times did we go to node.prev?
                        next_index = index // node.repetition_factor
                        repetition_stack.append(node.repetition_factor)
                        extra_index_stack.append(index - next_index * node.repetition_factor)
                        index = next_index
                        node = node.prev
    
                # Now which permutation are we on?
                # We have it encoded in a version of
                # https://en.wikipedia.org/wiki/Factorial_number_system
                index = 0
                while 0 < len(repetition_stack):
                    index = index * repetition_stack.pop() + extra_index_stack.pop()
    
                # Now we know the values, and we know where it is in the list...find it.
                permutation = []
                # The index of the value we will try next.
                i = 0
                while 0 < len(value_count):
                    # Each iteration will either add a value to permutation,
                    # or discard a choice for the next value.
                    value, count = value_count.peekitem(i)
                    if count <= 1:
                        value_count.pop(value)
                    else:
                        value_count[value] = count-1
                    # How many solutions have value here? Do a multi-choose.
                    count_with_value = 1
                    total_seen = 0
                    for count_inner in value_count.values():
                        for j in range(count_inner):
                            total_seen += 1
                            count_with_value = (count_with_value * total_seen) // (j + 1)
    
                    if index < count_with_value:
                        permutation.append(value)
                        i = 0
                    else:
                        value_count[value] = count
                        index -= count_with_value
                        i += 1
    
                return permutation
    
    
    def find_solution_structure (coins, target, min_len, max_len):
        coins = sorted(coins) # Now I know where largest and smallest are...
        # By (total_sum, total_len), the node that encodes it.
        # We start with the empty sum.
        reachable = {(0, 0): TargetNode()}
        answer = None
    
        # For each coin (largest to smallest) that we might use.
        i = len(coins)
        while 0 < i:
            i -= 1
    
            # We will store the edits to make to reachable while
            # we run through it, then edit later. It is dangerous
            # to edit a collection while searching throught it.
            this_reachable = {}
            to_remove = []
    
            # Now do the dynamic programming thing for adding some
            # of this term.
            for key, node in reachable.items():
                total, count = key
                # Pruning sanity check.
                if total + (max_len - count) * coins[i] < target:
                    # Cannot possibly sum high enough from here.
                    to_remove.append(key)
                    continue
    
                j = 0
                while j < max_len - count:
                    j += 1
                    new_node = TargetNode(term=coins[i], this_len=j, prev=node)
                    # Do we have a solution?
                    if new_node.total_sum == target:
                        # Did we find part of our answer?
                        if new_node.total_sum == target and min_len <= new_node.total_len:
                            answer = new_node.add_prev_solution(answer)
                        # Regardless, we're done with j.
                        break # out of j while-loop
    
                    # What is the smallest answer we can be leading to?
                    min_final_total = new_node.total_sum
                    if new_node.total_len < min_len:
                        min_final_total += (min_len - new_node.total_len) * coins[0]
                    else:
                        min_final_total += coins[0]
                    if target < min_final_total:
                        break # out of j while-loop
    
                    # Record it.
                    new_key = (new_node.total_sum, new_node.total_len)
                    this_reachable[new_key] = new_node.add_prev_solution(this_reachable.get(new_key, None))
    
            # Now record it permanently.
            for key in to_remove:
                reachable.pop(key)
            for key, node in this_reachable.items():
                reachable[key] = node.add_prev_solution(reachable.get(key, None))
    
        return answer
    
    for coins, target, min_len, max_len in (
            ([ 13, 17, 23, 24, 25 ], 30, 1, 10),
            ([ 60, 61, 62, 63, 64 ], 600, 1, 10),
            ([ 1, 2, 3, 4, 5 ], 13, 3, 4)
            ):
        node = find_solution_structure(coins, target, min_len, max_len)
        print(coins, target, min_len, max_len, node, node.total_count)
        for x in node.value_count_iter():
            print(" ", x)
        print("and all orders")
        for c, x in enumerate(node):
            print(">", c, x, node.find_ordered_solution(c))