I was looking at the accepted solution to this question, which provides a Python implementation of an algorithm for producing unique permutations in lexicographic order. I have a somewhat shortened implementation:
```python
def permutations(seq):
    seq = sorted(seq)
    while True:
        yield seq
        k = l = None
        for k in range(len(seq) - 1):
            if seq[k] < seq[k + 1]:
                l = k + 1
                break
        else:
            return
        (seq[k], seq[l]) = (seq[l], seq[k])
        seq[k + 1:] = seq[-1:k:-1]
```
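Separately from the aliasing problem discussed below, the shortened loop differs from the usual lexicographic next-permutation step. It stops at the *first* index `k` with `seq[k] < seq[k + 1]` and always swaps with position `k + 1`. The standard step instead uses the *last* such `k`, swaps `seq[k]` with the rightmost element greater than it, and then reverses the suffix after `k`. With `(1, 2, 3)` the loop above produces only four of the six permutations, skipping `[1, 3, 2]` and `[2, 1, 3]`. Here is a minimal sketch of the standard step. The name `lexicographic_permutations` is hypothetical, and it assumes the items are mutually comparable:

```python
def lexicographic_permutations(items):
    """Hypothetical helper: yield each distinct permutation of items in lexicographic order.

    Each yielded list is a fresh copy, so collecting the results with list() is safe.
    """
    seq = sorted(items)
    n = len(seq)
    while True:
        yield seq.copy()
        # Step 1: find the largest k such that seq[k] < seq[k + 1].
        k = n - 2
        while k >= 0 and seq[k] >= seq[k + 1]:
            k -= 1
        if k < 0:
            return  # seq is in non-increasing order: this was the last permutation
        # Step 2: find the largest l > k such that seq[k] < seq[l].
        l = n - 1
        while seq[k] >= seq[l]:
            l -= 1
        # Step 3: swap, then reverse the suffix so it is in ascending order again.
        seq[k], seq[l] = seq[l], seq[k]
        seq[k + 1:] = reversed(seq[k + 1:])


print(list(lexicographic_permutations((1, 2, 3))))
print(list(lexicographic_permutations((1, 2, 1))))
```

Both comparisons use `>=`, so repeated items such as the two `1`s in `(1, 2, 1)` produce each distinct ordering exactly once.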
What's really strange to me is that if I call `list` on the output of this function, I get wrong results. However, if I iterate over the results of this function one at a time, I get the expected results.
```python
>>> list(permutations((1,2,1)))
[[2, 1, 1], [2, 1, 1], [2, 1, 1]]
>>> for p in permutations((1,2,1)):
...     print(p)
...
[1, 1, 2]
[1, 2, 1]
[2, 1, 1]
```
What the?! Another example:
```python
>>> list(permutations((1,2,3)))
[[3, 2, 1], [3, 2, 1], [3, 2, 1], [3, 2, 1]]
>>> for p in permutations((1,2,3)):
...     print(p)
...
[1, 2, 3]
[2, 3, 1]
[3, 1, 2]
[3, 2, 1]
```
A list comprehension also yields the incorrect values:
```python
>>> [p for p in permutations((1,2,3))]
[[3, 2, 1], [3, 2, 1], [3, 2, 1], [3, 2, 1]]
```
I have no idea what's going on here! I've not seen this before. I can write other functions that use generators and I don't run into this:
```python
>>> def seq(n):
...     for i in range(n):
...         yield i
...
>>> list(seq(5))
[0, 1, 2, 3, 4]
```
What's going on in my example above that causes this?
You modify `seq` within the generator after you've yielded it. You keep yielding the same object and then modifying it in place.
```python
(seq[k], seq[l]) = (seq[l], seq[k])  # this mutates seq
seq[k + 1:] = seq[-1:k:-1]           # this mutates seq
```
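The same effect can be reproduced without any permutation logic. In the sketch below, `reuse_one_list` is a hypothetical stand-in for the original generator. `list()` stores references to whatever the generator yields, so every slot points at one list in its final state. A `for` loop, by contrast, sees each state before the next mutation happens:

```python
def reuse_one_list(n):
    """Hypothetical demo: yields the *same* list object n times, mutating it in between."""
    buf = []
    for i in range(n):
        buf.append(i)  # in-place mutation, like the swap and slice assignment above
        yield buf      # no copy: the caller receives a reference to buf itself


collected = list(reuse_one_list(3))
print(collected)                                   # [[0, 1, 2], [0, 1, 2], [0, 1, 2]]
print(all(x is collected[0] for x in collected))  # True: one object, three references

for x in reuse_one_list(3):
    print(x)  # [0], then [0, 1], then [0, 1, 2]
```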
Note that your list contains the same object multiple times:
```python
In [2]: ps = list(permutations((1,2,1)))

In [3]: ps
Out[3]: [[2, 1, 1], [2, 1, 1], [2, 1, 1]]

In [4]: [hex(id(p)) for p in ps]
Out[4]: ['0x105cb3b48', '0x105cb3b48', '0x105cb3b48']
```
So, try yielding a copy:
```python
def permutations(seq):
    seq = sorted(seq)
    while True:
        yield seq.copy()  # hand out a snapshot; later mutations of seq won't affect it
        k = None
        l = None
        for k in range(len(seq) - 1):
            if seq[k] < seq[k + 1]:
                l = k + 1
                break
        else:
            return
        (seq[k], seq[l]) = (seq[l], seq[k])
        seq[k + 1:] = seq[-1:k:-1]
```
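If callers don't need to modify the results, another option is to yield an immutable snapshot. The sketch below is a hypothetical variant, `permutations_as_tuples`, that keeps the stepping logic above unchanged and yields `tuple(seq)` instead of `seq.copy()`. Tuples are also hashable, so the results can go straight into a `set` or be used as `dict` keys:

```python
def permutations_as_tuples(seq):
    """Hypothetical variant of the fixed generator: same stepping logic, tuple snapshots."""
    seq = sorted(seq)
    while True:
        # tuple(seq) copies the current contents; the in-place edits below can't reach it.
        yield tuple(seq)
        for k in range(len(seq) - 1):
            if seq[k] < seq[k + 1]:
                l = k + 1
                break
        else:
            return
        seq[k], seq[l] = seq[l], seq[k]
        seq[k + 1:] = seq[-1:k:-1]


ps = list(permutations_as_tuples((1, 2, 1)))
print(ps)  # [(1, 1, 2), (1, 2, 1), (2, 1, 1)]
```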
And, voila:
```python
In [5]: def permutations(seq):
   ...:     seq = sorted(seq)
   ...:     while True:
   ...:         yield seq.copy()
   ...:         k = None
   ...:         l = None
   ...:         for k in range(len(seq) - 1):
   ...:             if seq[k] < seq[k + 1]:
   ...:                 l = k + 1
   ...:                 break
   ...:         else:
   ...:             return
   ...:
   ...:         (seq[k], seq[l]) = (seq[l], seq[k])
   ...:         seq[k + 1:] = seq[-1:k:-1]
   ...:

In [6]: ps = list(permutations((1,2,1)))

In [7]: ps
Out[7]: [[1, 1, 2], [1, 2, 1], [2, 1, 1]]
```
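As to why `print`ing in a for-loop doesn't reveal this behavior: at the moment each item is printed, `seq` still has the "correct" value, because the generator is paused at `yield` and hasn't yet run the code that mutates it. Consider this trace with the original, non-copying generator: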
```python
In [10]: result = []
    ...: for i, x in enumerate(permutations((1,2,1))):
    ...:     print("iteration ", i)
    ...:     print(x)
    ...:     result.append(x)
    ...:     print(result)
    ...:
iteration  0
[1, 1, 2]
[[1, 1, 2]]
iteration  1
[1, 2, 1]
[[1, 2, 1], [1, 2, 1]]
iteration  2
[2, 1, 1]
[[2, 1, 1], [2, 1, 1], [2, 1, 1]]
```
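This trace also suggests a consumer-side workaround for when you can't change the generator: copy each item as it arrives. The copy is taken while the generator is suspended at its `yield`, before the next in-place mutation runs. Here is a minimal sketch, using a hypothetical `shared_buffer` generator as a stand-in for the original non-copying `permutations`:

```python
def shared_buffer(n):
    """Hypothetical generator that mutates and re-yields one list, like the original permutations."""
    buf = [0] * n
    for i in range(n):
        buf[i] = 1
        yield buf


# Each list(b) runs while shared_buffer is paused at `yield`, i.e. before the next mutation.
snapshots = [list(b) for b in shared_buffer(3)]
print(snapshots)  # [[1, 0, 0], [1, 1, 0], [1, 1, 1]]

# map is lazy too: it copies each item as soon as it is produced.
print(list(map(list, shared_buffer(3))))  # [[1, 0, 0], [1, 1, 0], [1, 1, 1]]

# Without a copy, all three entries alias the same final list.
print(list(shared_buffer(3)))  # [[1, 1, 1], [1, 1, 1], [1, 1, 1]]
```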