I'm trying to figure out how the following backtracking solution works to generate all the permutations of integers given as a list:
def permutations(arr):
    res = []
    backtrack(arr, [], set(), res)
    print(res)

def backtrack(arr, temp, visited, res):
    if len(temp) == len(arr):
        res.append(temp[:])
    else:
        for num in arr:
            if num in visited: continue
            visited.add(num)
            temp.append(num)
            backtrack(arr, temp, visited, res)
            visited.remove(num)
            temp.pop()
Executing with the following:
permutations([1, 2, 3])
The result is as expected:
[[1, 2, 3], [1, 3, 2], [2, 1, 3], [2, 3, 1], [3, 1, 2], [3, 2, 1]]
What I don't understand is the temp.pop() call at the end of the loop. I know that pop() discards the last element of the list, but why is it necessary here?
I'd really appreciate it if someone could explain this to me.
It is necessary because on the next iteration the loop appends another element, so if you didn't pop beforehand, the list would keep growing. temp.pop() undoes the temp.append(num) made earlier in the same iteration, just as visited.remove(num) undoes visited.add(num), so every iteration of the loop starts from the same prefix. Since backtrack adds a result only if the length of temp equals the length of the original arr, it won't add anything beyond the first permutation [1, 2, 3] (because from there on, the temp list keeps growing).
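To see the append/pop pairing directly, here is a minimal sketch of the same algorithm that also records the state of temp after every append and every pop. The names permutations_traced and log are hypothetical additions for illustration; the function returns the results instead of printing them.

```python
def permutations_traced(arr):
    """Same backtracking as above, but returns (results, log of temp states)."""
    res, log = [], []

    def backtrack(temp, visited):
        if len(temp) == len(arr):
            res.append(temp[:])  # store a copy, since temp keeps changing
            return
        for num in arr:
            if num in visited:
                continue
            visited.add(num)
            temp.append(num)                      # choose num
            log.append(("append", num, temp[:]))
            backtrack(temp, visited)              # explore everything starting with temp
            visited.remove(num)
            temp.pop()                            # un-choose num: back to the old prefix
            log.append(("pop", num, temp[:]))

    backtrack([], set())
    return res, log


perms, log = permutations_traced([1, 2, 3])
for action, num, state in log[:6]:
    print(action, num, state)
# append 1 [1]
# append 2 [1, 2]
# append 3 [1, 2, 3]
# pop 3 [1, 2]
# pop 2 [1]
# append 3 [1, 3]
```

Each pop returns temp to exactly the state it had before the matching append, which is why the loop at the [1] level can try 3 after 2 and build [1, 3] rather than [1, 2, 3, 3].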
We can simply remove that line (temp.pop()) and look at the result:
[[1, 2, 3]]
If we additionally change the condition so that results are added whenever len(temp) >= len(arr), we can watch temp grow:
[[1, 2, 3], [1, 2, 3, 3], [1, 2, 3, 3, 2], [1, 2, 3, 3, 2, 3]]
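Both experiments can be reproduced with one sketch. backtrack_experiment and its pop and at_least switches are hypothetical names, not part of the original code: pop=False skips temp.pop() (visited.remove is still done), and at_least=True records results when len(temp) >= len(arr) instead of ==.

```python
def backtrack_experiment(arr, pop=True, at_least=False):
    """Original backtracking with two hypothetical switches for the experiments."""
    res = []

    def backtrack(temp, visited):
        # at_least=True uses >=, otherwise the original == test
        done = len(temp) >= len(arr) if at_least else len(temp) == len(arr)
        if done:
            res.append(temp[:])  # record a snapshot and do not recurse further
        else:
            for num in arr:
                if num in visited:
                    continue
                visited.add(num)
                temp.append(num)
                backtrack(temp, visited)
                visited.remove(num)
                if pop:  # pop=False reproduces the "line removed" experiment
                    temp.pop()

    backtrack([], set())
    return res


print(backtrack_experiment([1, 2, 3]))
# [[1, 2, 3], [1, 3, 2], [2, 1, 3], [2, 3, 1], [3, 1, 2], [3, 2, 1]]
print(backtrack_experiment([1, 2, 3], pop=False))
# [[1, 2, 3]]
print(backtrack_experiment([1, 2, 3], pop=False, at_least=True))
# [[1, 2, 3], [1, 2, 3, 3], [1, 2, 3, 3, 2], [1, 2, 3, 3, 2, 3]]
```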
We get only four results here because, once temp has reached [1, 2, 3], every further recursive call finds len(temp) >= len(arr), copies temp into res and returns immediately, without ever reaching the for loop.
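For comparison, the undo step is only needed because a single temp list (and a single visited set) is shared and mutated across all recursive calls. A sketch of a variant that avoids the explicit undo by handing each call fresh objects follows; permutations_fresh is a hypothetical name, and this is an alternative style, not the original solution.

```python
def permutations_fresh(arr):
    """Variant that never mutates temp, so there is nothing to pop."""
    res = []

    def backtrack(temp, visited):
        if len(temp) == len(arr):
            res.append(temp)  # temp is already a private list, no copy needed
            return
        for num in arr:
            if num in visited:
                continue
            # temp + [num] and visited | {num} build new objects, so the
            # caller's temp and visited are left untouched
            backtrack(temp + [num], visited | {num})

    backtrack([], frozenset())
    return res


print(permutations_fresh([1, 2, 3]))
# [[1, 2, 3], [1, 3, 2], [2, 1, 3], [2, 3, 1], [3, 1, 2], [3, 2, 1]]
```

The trade-off is that this version copies the prefix on every call, whereas the append/pop version reuses one list and only copies it when a complete permutation is stored.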