python python-3.x performance knapsack-problem processing-efficiency

0/1 Knapsack Problem Simplified in Python


I have the following code that is performing too slowly. The idea is similar to the 0/1 knapsack problem: given an integer n, you have to find distinct numbers in the range 1 to n - 1 whose squares add up to n squared.

For example, if n is 5, it should output 3, 4 because 3 ** 2 + 4 ** 2 = 25 = 5 ** 2. I have been struggling to understand how to make this more efficient and would like to know the concepts used to improve the efficiency of this type of program.

Some other examples:

n = 8 [None]
n = 30 [1, 3, 7, 29]
n = 16 [2, 3, 5, 7, 13]
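To make the requirement concrete, here is a minimal sketch of a checker for a proposed answer. The helper name is_valid_decomposition is hypothetical, and the "no number used twice" condition is an assumption taken from the 0/1 framing of the problem.

```python
def is_valid_decomposition(n, parts):
    """Check that parts are distinct integers in 1..n-1 whose squares sum to n**2."""
    return (
        len(set(parts)) == len(parts)            # no number used twice (assumed)
        and all(1 <= p < n for p in parts)       # each number lies in 1..n-1
        and sum(p * p for p in parts) == n * n   # the squares add up to n squared
    )


print(is_valid_decomposition(5, [3, 4]))
print(is_valid_decomposition(30, [1, 3, 7, 29]))
print(is_valid_decomposition(16, [2, 3, 5, 7, 13]))
```

A check like this is cheap to run against any candidate output, which helps when comparing a faster implementation with the slow one.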

I found some posts regarding this, but they seemed limited to two numbers, whereas my program needs to use as many numbers as it takes to add up to the original number squared.

I watched some videos on the 0/1 knapsack problem. I struggled to apply the same concepts to my own program as the issue was quite different. They had things they could put in their bag that had a weight and profit.

This has all been hurting my brain for a few hours, and if anyone could even point me in the right direction I would appreciate it highly. Thank you :)

from math import sqrt
def decompose(n):

    lst = []

    sets = []

    temp = []

    perm = {}

    out = []

    for i in range (n):
        lst.append(i**2)


    for i in lst:
        for x in sets:
            temp.append(i + x)
            perm[i + x] = (i, x)
        for x in temp:
            if x not in sets:
                sets.append(x)
        if i not in sets:
            sets.append(i)
        temp = []

    if n**2 not in perm.keys():
        return None

    for i in perm[n**2]:
        if str(i).isdigit():
            out.append(i)
        if i == ' ':
            out.append(i)


    for i in out:
        if i not in lst:
            out.remove(i)
            for i in perm[i]:
                if str(i).isdigit():
                    out.append(i)
                if i == ' ':
                    out.append(i)

    out.sort()

    return [sqrt(i) for i in out]

Solution

  • Here is a recursive program to find a decomposition. It is probably not the fastest possible approach, and it is certainly not optimal for searching large ranges of inputs, since it doesn't cache intermediate results.

    In this version, the function find_decomposition(n, k, uptonow, used) tries to find a decomposition of n² using only the numbers from k to n-1, given that the numbers in used have already been chosen and that their squares add up to the partial sum uptonow. The function recursively tries two possibilities: either the solution doesn't include k, or it does. It tries one possibility first and returns the result if that works; if not, it tries the other. So it first looks for a solution without k. If that fails, it does a quick test to see whether adding just k completes a solution. If that also fails, it recursively looks for a solution that uses k, so the set of used numbers now includes k and the sum uptonow is increased by k².

    Many variations can be thought of:

    • Instead of running from 1 to n-1, k could run in the reverse order. Be careful to adjust the conditions of the if tests accordingly.
    • Instead of first trying a solution that doesn't include k, start with trying a solution that does include k.
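The first variation (running k from high to low) can be sketched as follows. This is only an illustration under stated assumptions, not a tuned implementation: the names decompose_descending and search are hypothetical, the sketch swaps in a loop over candidates instead of the include/exclude recursion, and math.isqrt requires Python 3.8 or later.

```python
from math import isqrt


def decompose_descending(n):
    """Try the largest candidates first; return an ascending list or None."""

    def search(remaining, upper):
        # remaining: the amount that still has to be covered by squares
        # upper: the largest number that may still be used
        if remaining == 0:
            return []
        # No k above isqrt(remaining) can fit, so start from there.
        for k in range(min(upper, isqrt(remaining)), 0, -1):
            # Passing k - 1 as the new bound keeps the numbers distinct.
            rest = search(remaining - k * k, k - 1)
            if rest is not None:
                return rest + [k]
        return None

    return search(n * n, n - 1)


for n in (5, 8, 11):
    print(n, decompose_descending(n))
```

Each recursive call here adds one number to the candidate solution, so the recursion depth is bounded by the number of terms rather than by n. The search can still take very long when no decomposition exists, and it may return a different valid decomposition than the ascending version does.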

    Note that for large n, the function can hit Python's maximum recursion depth. E.g. when n=1000, there are about 2^999 possible subsets of the numbers to be recursively checked. This can lead to a recursion 999 levels deep, which is right at the edge of CPython's default recursion limit of 1000.
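The arithmetic behind that remark can be checked directly. This is a small sketch; the variable names are illustrative only, and the recursion limit printed depends on the interpreter and its configuration.

```python
import sys

n = 1000
candidates = n - 1                # the numbers 1 .. n-1
subset_count = 2 ** candidates    # each number is either used or not used

print(len(str(subset_count)))     # number of decimal digits in 2**999
print(sys.getrecursionlimit())    # the interpreter's current recursion limit
```

The subset count has 301 decimal digits, so exhaustive enumeration is out of reach; only pruning or a different formulation keeps the search feasible.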

    Using up the high numbers first is probably beneficial, as it quickly reduces the gap left to fill. Luckily, for large n there are many possible solutions, so one can be found quickly. Note that in the general knapsack problem as described by @Kevin Wang, if no solution exists, any approach with 999 numbers will take too long to finish.

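    def find_decomposition(n, k=1, uptonow=0, used=()):
        # used holds the numbers chosen so far; uptonow is the sum of their squares
        # (immutable default for used, to avoid the shared mutable default pitfall)
    
        # first try without k
        if k < n-1:
            decomp = find_decomposition(n, k+1, uptonow, used)
            if decomp is not None:
                return decomp
    
        # now try including k
        used_with_k = [*used, k]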
        if uptonow + k * k == n * n:
            return used_with_k
        elif k < n-1 and uptonow + k * k + (k+1)*(k+1) <= n * n:
            # no need to try k if k doesn't fit together with at least one higher number
            return find_decomposition(n, k+1, uptonow+k*k, used_with_k)
        return None
    
    for n in range(5,1001):
        print(n, find_decomposition(n))
    

    Output:

    5 [3, 4]
    6 None
    7 [2, 3, 6]
    8 None
    9 [2, 4, 5, 6]
    10 [6, 8]
    11 [2, 6, 9]
    12 [1, 2, 3, 7, 9]
    13 [5, 12]
    14 [4, 6, 12]
    15 [9, 12]
    16 [3, 4, 5, 6, 7, 11]
    ...
    

    PS: This link contains code about a related problem, but where squares can be repeated: https://www.geeksforgeeks.org/minimum-number-of-squares-whose-sum-equals-to-given-number-n/
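Relating back to the remark that the recursive version doesn't cache intermediate results: one way to cache them is the subset-sum table that underlies the 0/1 knapsack approach. The sketch below is an assumption-laden illustration, not part of the answer's code. The names decompose_dp and last_used are hypothetical, and the table can grow to about n² entries, so it is only practical for moderate n.

```python
def decompose_dp(n):
    """Subset-sum table over the squares 1**2 .. (n-1)**2; returns a list or None."""
    target = n * n
    # last_used[s] = the number whose square was added last to first reach sum s
    last_used = {0: 0}
    for k in range(1, n):
        square = k * k
        # Snapshot the known sums so k is added at most once to any of them
        # (this is the "0/1" part of the knapsack idea).
        for s in list(last_used):
            new_sum = s + square
            if new_sum <= target and new_sum not in last_used:
                last_used[new_sum] = k
        if target in last_used:
            break
    if target not in last_used:
        return None
    # Walk back from the target, peeling off one square per step.
    parts = []
    s = target
    while s > 0:
        k = last_used[s]
        parts.append(k)
        s -= k * k
    return sorted(parts)


for n in (5, 8, 30):
    print(n, decompose_dp(n))
```

Each reachable sum is stored once, no matter how many subsets produce it, which is the caching the recursive search lacks. The trade-off is memory: the number of stored sums is bounded by n² + 1.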