Python groupby behaves strangely


from itertools import groupby

source = [ [1,2], [1,3], [2, 1] ]
gby = groupby(source, lambda x: x[0])

print 'as list'
for key, vals in list(gby):
    print 'key {}'.format(key)
    for val in vals:
        print '  val {}'.format(val)

print

print 'as iter'
gby = groupby(source, lambda x: x[0])
for key, vals in gby:
    print 'key {}'.format(key)
    for val in vals:
        print '  val {}'.format(val)

Results:

as list
key 1
key 2
  val [2, 1]

as iter
key 1
  val [1, 2]
  val [1, 3]
key 2
  val [2, 1]

What is wrong with list(gby)? I would expect list() to be a pure function, so how does it manage to corrupt the internal state?


Solution

  • The documentation makes a note about this:

    The returned group is itself an iterator that shares the underlying iterable with groupby(). Because the source is shared, when the groupby() object is advanced, the previous group is no longer visible. So, if that data is needed later, it should be stored as a list:

    groups = []
    uniquekeys = []
    data = sorted(data, key=keyfunc)
    for k, g in groupby(data, keyfunc):
        groups.append(list(g))      # Store group iterator as a list
        uniquekeys.append(k)
    

    You're exhausting the groupby object (by turning it into a list) before iterating over the returned group iterators, so every group except the last one is lost.

    The reason for this is easier to figure out by looking at the Python implementation of the function:

    class groupby(object):
        # [k for k, g in groupby('AAAABBBCCDAABBB')] --> A B C D A B
        # [list(g) for k, g in groupby('AAAABBBCCD')] --> AAAA BBB CC D
        def __init__(self, iterable, key=None):
            if key is None:
                key = lambda x: x
            self.keyfunc = key
            self.it = iter(iterable)
            self.tgtkey = self.currkey = self.currvalue = object()
        def __iter__(self):
            return self
        def next(self):
            while self.currkey == self.tgtkey:
                self.currvalue = next(self.it)
                self.currkey = self.keyfunc(self.currvalue)
            self.tgtkey = self.currkey
            return (self.currkey, self._grouper(self.tgtkey))
        def _grouper(self, tgtkey):  # This is the "group" iterator
            while self.currkey == tgtkey:  # self.currkey != tgtkey if you advance groupby and then try to use this object.
                yield self.currvalue
                self.currvalue = next(self.it)
                self.currkey = self.keyfunc(self.currvalue)
    

    Calling next() on the groupby object advances the shared underlying iterator (self.it) until the key changes, then returns the new key (self.currkey) together with a _grouper iterator. _grouper takes that key as an argument (called tgtkey) and yields values, recalculating self.currkey after each one, until self.currkey differs from tgtkey, meaning it has returned all the values corresponding to that key. So, if you advance groupby before consuming a _grouper object, self.currkey no longer equals that object's tgtkey, and the _grouper iterator yields nothing.
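The effect described above can be reproduced in a few lines. This is a minimal sketch written for Python 3, reusing the question's `source` list; the names `group1`, `group2`, `stale`, and `fresh` are only illustrative. It takes two groups from the groupby object before consuming either of them:

```python
from itertools import groupby

source = [[1, 2], [1, 3], [2, 1]]
gby = groupby(source, key=lambda x: x[0])

key1, group1 = next(gby)   # group for key 1, not consumed yet
key2, group2 = next(gby)   # advancing skips past the remaining key-1 items

stale = list(group1)       # the key-1 group has nothing left to yield
fresh = list(group2)       # the current group is still intact

print(key1, stale)         # 1 []
print(key2, fresh)         # 2 [[2, 1]]
```

Consuming `group1` before the second `next(gby)` call would have returned both key-1 items.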

    If for some reason you do need to store the groupby results in a list, you have to do it like this:

    gby_list = []
    for key, vals in gby:
        gby_list.append((key, list(vals)))  # append() takes a single argument, so pass a tuple
    

    Or:

    gby_list = [(key, list(vals)) for key, vals in gby]
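As a rough end-to-end check, the sketch below wraps that approach in a small helper and runs it on the question's data. It is written for Python 3, and `materialize_groups` is a hypothetical name, not part of itertools. The sort step follows the documentation excerpt quoted earlier, since groupby only merges adjacent items with equal keys.

```python
from itertools import groupby


def materialize_groups(items, keyfunc):
    """Return [(key, [members, ...]), ...] with every group copied into a list."""
    ordered = sorted(items, key=keyfunc)  # groupby only merges adjacent equal keys
    # Each group is copied with list() before groupby advances to the next key.
    return [(key, list(group)) for key, group in groupby(ordered, key=keyfunc)]


source = [[1, 2], [1, 3], [2, 1]]
groups = materialize_groups(source, lambda x: x[0])

for key, vals in groups:
    print('key {}'.format(key))
    for val in vals:
        print('  val {}'.format(val))
```

Because every group is already a plain list, `groups` can be iterated any number of times and prints the same keys and values as the "as iter" run in the question.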