
Python: calling list() on generator object produces incorrect result


I was looking at the accepted solution to this question, which provides a Python implementation of an algorithm for producing unique permutations in lexicographic order. I have a somewhat shortened implementation:

def permutations(seq):
    seq = sorted(seq)
    while True:
        yield seq
        k = l = None
        for k in range(len(seq) - 1):
            if seq[k] < seq[k + 1]:
                l = k + 1
                break
        else:
            return

        (seq[k], seq[l]) = (seq[l], seq[k])
        seq[k + 1:] = seq[-1:k:-1]

What's really strange to me is that if I call list() on the output of this function, I get wrong results. However, if I iterate over the results one at a time, I get the expected results.

>>> list(permutations((1,2,1)))
[[2, 1, 1], [2, 1, 1], [2, 1, 1]]
>>> for p in permutations((1,2,1)):
...   print(p)
... 
[1, 1, 2]
[1, 2, 1]
[2, 1, 1]

^^^What the?! Another example:

>>> list(permutations((1,2,3)))
[[3, 2, 1], [3, 2, 1], [3, 2, 1], [3, 2, 1]]
>>> for p in permutations((1,2,3)):
...   print(p)
... 
[1, 2, 3]
[2, 3, 1]
[3, 1, 2]
[3, 2, 1]

A list comprehension also yields the incorrect values:

>>> [p for p in permutations((1,2,3))]
[[3, 2, 1], [3, 2, 1], [3, 2, 1], [3, 2, 1]]

I have no idea what's going on here! I've not seen this before. I can write other functions that use generators and I don't run into this:

>>> def seq(n):
...   for i in range(n):
...     yield i
... 
>>> list(seq(5))
[0, 1, 2, 3, 4]

What's going on in my example above that causes this?


Solution

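  • You modify seq within the generator after you've yielded it. Every yield hands out the same list object, and the generator then mutates that object in place, so every reference that list() collected ends up showing the final state.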

        (seq[k], seq[l]) = (seq[l], seq[k]) # this mutates seq
        seq[k + 1:] = seq[-1:k:-1] # this mutates seq
    

    Note that your list contains the same object multiple times:

    In [2]: ps = list(permutations((1,2,1)))
    
    In [3]: ps
    Out[3]: [[2, 1, 1], [2, 1, 1], [2, 1, 1]]
    
    In [4]: [hex(id(p)) for p in ps]
    Out[4]: ['0x105cb3b48', '0x105cb3b48', '0x105cb3b48']
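
The aliasing can be reproduced without the permutation logic. The following is a minimal sketch; `countdown_shared` and `countdown_copied` are hypothetical names invented for this illustration, not part of the question's code. One generator yields the same list on every step, the other yields a copy.

```python
def countdown_shared(n):
    """Yield the same list object each time, mutating it between yields."""
    state = [n]
    while state[0] > 0:
        yield state        # the same object is handed out on every step
        state[0] -= 1      # in-place mutation after the yield


def countdown_copied(n):
    """Yield a fresh copy each time, so earlier results are unaffected."""
    state = [n]
    while state[0] > 0:
        yield state.copy()
        state[0] -= 1


shared = list(countdown_shared(3))
copied = list(countdown_copied(3))

print(shared)  # [[0], [0], [0]]
print(copied)  # [[3], [2], [1]]

# Every element of `shared` is the very same list object.
print(all(item is shared[0] for item in shared))  # True
```

By the time list() returns, the shared list has already been decremented to its final state, which is why all three entries look identical.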
    

    So, try yielding a copy:

    def permutations(seq):
        seq = sorted(seq)
        while True:
            yield seq.copy()
            k = None
            l = None
            for k in range(len(seq) - 1):
                if seq[k] < seq[k + 1]:
                    l = k + 1
                    break
            else:
                return
    
            (seq[k], seq[l]) = (seq[l], seq[k])
            seq[k + 1:] = seq[-1:k:-1]
    

    And, voilà:

    In [6]: ps = list(permutations((1,2,1)))
    
    In [7]: ps
    Out[7]: [[1, 1, 2], [1, 2, 1], [2, 1, 1]]
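
As a possible variation on yielding a copy, the generator could yield an immutable snapshot instead. This is only a sketch: `permutations_as_tuples` is a hypothetical name, and the update steps are taken unchanged from the question's code. A tuple cannot be mutated by the caller either, and it is hashable, which may be convenient if the results go into a set.

```python
def permutations_as_tuples(seq):
    """Same loop as the question's generator, but yields immutable tuples."""
    seq = sorted(seq)
    while True:
        # tuple(seq) is a snapshot; later mutations of seq cannot affect it.
        yield tuple(seq)
        for k in range(len(seq) - 1):
            if seq[k] < seq[k + 1]:
                l = k + 1
                break
        else:
            return

        seq[k], seq[l] = seq[l], seq[k]   # mutates seq in place
        seq[k + 1:] = seq[-1:k:-1]        # mutates seq in place


print(list(permutations_as_tuples((1, 2, 1))))
# [(1, 1, 2), (1, 2, 1), (2, 1, 1)]
```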
    

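    Printing in a for-loop doesn't reveal this behavior because each value is printed immediately, before the generator resumes and mutates seq, so at that moment seq still has the "correct" value. Accumulating the yielded objects instead makes the problem visible: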

    In [10]: result = []
        ...: for i, x in enumerate(permutations((1,2,1))):
        ...:     print("iteration ", i)
        ...:     print(x)
        ...:     result.append(x)
        ...:     print(result)
        ...:
    iteration  0
    [1, 1, 2]
    [[1, 1, 2]]
    iteration  1
    [1, 2, 1]
    [[1, 2, 1], [1, 2, 1]]
    iteration  2
    [2, 1, 1]
    [[2, 1, 1], [2, 1, 1], [2, 1, 1]]
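
If the generator itself cannot be changed, the same reasoning suggests a consumer-side workaround: copy each value as it arrives, before asking the generator for the next one. The sketch below reuses the question's original generator unchanged; the variable names `snapshots` and `aliased` are made up for this example.

```python
def permutations(seq):
    """The question's generator, unchanged: it yields the same list each time."""
    seq = sorted(seq)
    while True:
        yield seq
        for k in range(len(seq) - 1):
            if seq[k] < seq[k + 1]:
                l = k + 1
                break
        else:
            return

        seq[k], seq[l] = seq[l], seq[k]
        seq[k + 1:] = seq[-1:k:-1]


# Copy each value as it arrives, before the generator resumes and mutates it.
snapshots = [list(p) for p in permutations((1, 2, 1))]
print(snapshots)  # [[1, 1, 2], [1, 2, 1], [2, 1, 1]]

# Without the copy, every entry is the same object in its final state.
aliased = list(permutations((1, 2, 1)))
print(aliased)  # [[2, 1, 1], [2, 1, 1], [2, 1, 1]]
```

Copying inside the generator is usually the safer design, since callers then do not need to know that the yielded object is reused.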