
How to adjust my code to use a backtracking algorithm to solve a combination problem


Problem description:

Return all combinations (i.e. all subsets) of an array. For example, for the array [1, 2, 3] the result is:

[]
[1]  [2]  [3]
[1, 2]  [1, 3]  [2, 3]
[1, 2, 3]

Yes, I know there are lots of ways to solve this, but I am trying to solve it with a backtracking algorithm. Below is my code:

def p(arr):
    ret = []
    # Using a visited boolean array to avoid duplicate traversal and to backtrack.
    visited = [False] * len(arr)
    def dfs(start_idx, temp):
        ret.append(temp)
        for i in range(start_idx, len(arr)):
            if not visited[i]:
                visited[i] = True
                dfs(start_idx + 1, temp + [arr[i]])
                visited[i] = False
    dfs(0, [])
    return ret

It returns [[], [1], [1, 2], [1, 2, 3], [1, 3], [2], [2, 3], [3], [3, 2]], which contains the wrong answer [3, 2].

From my understanding, DFS with backtracking should only traverse the array in one direction, from left to right, but [3, 2] clearly goes in the reverse direction.

How should I understand this, and how can I fix my code?
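To see where [3, 2] comes from, here is a small tracing sketch. p_traced is a hypothetical name: it is the question's p() with its logic unchanged, plus a hypothetical calls log that records, for every recursive call, the index i just chosen and the start_idx handed to the child.

```python
def p_traced(arr):
    """The question's p() with its logic unchanged, plus a log of recursive calls."""
    ret = []
    calls = []  # hypothetical log: (index i just chosen, start_idx handed to the child)
    visited = [False] * len(arr)

    def dfs(start_idx, temp):
        ret.append(temp)
        for i in range(start_idx, len(arr)):
            if not visited[i]:
                visited[i] = True
                calls.append((i, start_idx + 1))
                dfs(start_idx + 1, temp + [arr[i]])  # the original start index
                visited[i] = False

    dfs(0, [])
    return ret, calls


ret, calls = p_traced([1, 2, 3])
# Calls whose child starts *before* the index just chosen, i.e. can look leftwards:
print([(i, child) for i, child in calls if child < i])  # [(2, 1)]
```

The only such call is choosing index 2 (the value 3) at the top level: the child restarts at index 1, and visited does not block index 1 because it is not on the current path, so [3, 2] is produced.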


Solution

  • Your algorithm uses a list of booleans to keep track of which elements are selected. But that is not a good way to do it: once you have selected an element i, you should make sure that you can only select elements with an index j > i.

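    You seem to do this with start_idx, but in the recursive call you only pass start_idx + 1, which advances one position past where the loop started, not past the element i you just picked. For example, at the top level (start_idx = 0), picking i = 2 (the value 3) calls dfs(1, [3]), and that call can still pick index 1 (the value 2), producing [3, 2].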

    So a quick fix is to pass i + 1 as the new start_idx:

    def p(arr):
        ret = []
        # Using a visited boolean array to avoid duplicate traversal and to backtrack.
        visited = [False] * len(arr)
        def dfs(start_idx, temp):
            ret.append(temp)
            for i in range(start_idx, len(arr)):
                if not visited[i]:
                    visited[i] = True
                    dfs(i + 1, temp + [arr[i]])  # i + 1 instead of start_idx + 1
                    visited[i] = False
        dfs(0, [])
        return ret

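    Now every recursive call starts strictly after the index just chosen, so no index already on the current path can be reached again. That makes visited obsolete, so we can remove those checks: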

    def p(arr):
        ret = []
        def dfs(start_idx, temp):
            ret.append(temp)
            for i in range(start_idx, len(arr)):
                dfs(i + 1, temp + [arr[i]])
        dfs(0, [])
        return ret

    That being said, in practice I would suggest using itertools.combinations.
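For comparison, here is a sketch with two functions whose names are hypothetical. subsets_backtrack is the same algorithm as the final p() above, written in the common choose / explore / un-choose style with one shared path list (so ret must store copies of it). subsets_itertools builds the power set from itertools.combinations for every length r from 0 to len(arr), in the spirit of the powerset recipe in the itertools documentation. Both produce the same subsets, just in a different order.

```python
from itertools import chain, combinations


def subsets_backtrack(arr):
    """Backtracking with a single mutable path: choose, explore, un-choose."""
    ret = []
    path = []

    def dfs(start_idx):
        ret.append(path[:])  # record a copy; path keeps changing afterwards
        for i in range(start_idx, len(arr)):
            path.append(arr[i])  # choose arr[i]
            dfs(i + 1)           # explore only indices after i
            path.pop()           # un-choose (the actual backtracking step)

    dfs(0)
    return ret


def subsets_itertools(arr):
    """All combinations of every length r = 0..len(arr), ordered by size."""
    return [list(c) for c in chain.from_iterable(
        combinations(arr, r) for r in range(len(arr) + 1))]


print(subsets_backtrack([1, 2, 3]))
# [[], [1], [1, 2], [1, 2, 3], [1, 3], [2], [2, 3], [3]]
print(subsets_itertools([1, 2, 3]))
# [[], [1], [2], [3], [1, 2], [1, 3], [2, 3], [1, 2, 3]]
```

Mutating one path list and popping after each recursive call avoids building a new list with temp + [arr[i]] at every step. The trade-off is that each recorded result has to be copied.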