
Group 'continuation' items in a list: is storing state in an itertools.groupby key function bad?


I'm new to Python and I'm trying to write a function that groups list items, with None marking a continuation of the current group, like so:

>>> g([1, None, 1, 1, None, None, 1])
[[1, None], [1], [1, None, None], [1]]

My real data has much more complex items, but I've simplified things down to the core for this question.

This is my solution so far:

import itertools

# input
x = [1, None, 1, 1, None, None, 1]

# desired output from g(x)
y = [[1, None], [1], [1, None, None], [1]]


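def f(x):
    # Stateful key function: the state lives on attributes of f itself.
    # f.lastx is only ever reset to None, so every non-None item
    # bumps the counter and therefore starts a new group.
    if x is None:
        f.lastx = x
    else:
        if x != f.lastx:
            f.counter += 1
    return f.counter


def g(x):
    # The state has to be reset by hand before every groupby run.
    f.lastx = None
    f.counter = 0
    z = [list(grp) for _, grp in itertools.groupby(x, f)]
    return z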


assert y == g(x)

This works, but I know it's very ugly.

Is there a better (and more Pythonic) way to do this? For example, one without a stateful key function?
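
A quick way to see what the stateful key function actually hands to groupby is to call it directly on the sample input. This sketch copies f from the question and resets its attributes by hand, as g does; keys and keys_again are illustrative names only.

```python
def f(x):
    # Copy of the question's stateful key function.
    if x is None:
        f.lastx = x
    else:
        if x != f.lastx:
            f.counter += 1
    return f.counter


x = [1, None, 1, 1, None, None, 1]

# Reset the state by hand, as g() does, then look at the raw keys.
f.lastx = None
f.counter = 0
keys = [f(v) for v in x]
print(keys)  # [1, 1, 2, 3, 3, 3, 4]

# Nothing resets the state automatically: a second pass keeps counting.
keys_again = [f(v) for v in x]
print(keys_again)  # [5, 5, 6, 7, 7, 7, 8]
```

The keys are simply a running count of non-None items, which is why a stateless running count can replace the function.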


Solution

  • You could combine itertools.groupby and itertools.accumulate:

    >>> from itertools import accumulate, groupby
    >>> dat = [1, None, 1, 1, None, None, 1]
    >>> it = iter(dat)
    >>> acc = accumulate(x is not None for x in dat)
    >>> [[next(it) for _ in g] for _, g in groupby(acc)]
    [[1, None], [1], [1, None, None], [1]]
    

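    This works because accumulate keeps a running count of non-None items. The count goes up by one at every non-None item, which starts a new group, and stays the same across the None continuations after it. The first value is True rather than 1, but since True == 1, groupby treats the first two positions as one run: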

    >>> list(accumulate(x is not None for x in dat))
    [True, 1, 2, 3, 3, 3, 4]
    

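    If you want to handle a stream, tee the iterator: one copy feeds accumulate and groupby, the other supplies the items themselves. The second copy lags the first by at most the one item groupby reads ahead to find the end of a group, so beyond the list built for the current group the extra memory use stays small.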

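    from itertools import accumulate, groupby, tee

    def cgroup(source):
        # Two independent iterators over the same source.
        it, it2 = tee(iter(source), 2)
        # Running count of non-None items: constant within one group.
        acc = accumulate(x is not None for x in it)
        for _, g in groupby(acc):
            # Take exactly as many items from the second copy as the group has.
            yield [next(it2) for _ in g]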
    

    This still gives

    >>> list(cgroup([1, None, 1, 1, None, None, 1]))
    [[1, None], [1], [1, None, None], [1]]
    

    but will work even with infinite sources:

    >>> from itertools import chain, islice, repeat
    >>> stream = chain.from_iterable(repeat([1, 1, None]))
    >>> list(islice(cgroup(stream), 10))
    [[1], [1, None], [1], [1, None], [1], [1, None], [1], [1, None], [1], [1, None]]
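
To see the keys groupby actually receives in the accumulate-based answer, each run can be listed together with its length. A minimal check using only the standard library; dat is the answer's sample list and runs is an illustrative name.

```python
from itertools import accumulate, groupby

dat = [1, None, 1, 1, None, None, 1]

# Running count of non-None items; the first element is the bool True.
acc = list(accumulate(x is not None for x in dat))
print(acc)  # [True, 1, 2, 3, 3, 3, 4]

# groupby compares consecutive keys with ==, and True == 1,
# so the first two positions end up in a single run.
runs = [(key, len(list(grp))) for key, grp in groupby(acc)]
print(runs)  # [(True, 2), (2, 1), (3, 3), (4, 1)]
```

The run lengths 2, 1, 3, 1 are exactly the group sizes in the expected output.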
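
A possible variant of cgroup, sketched here as an assumption rather than taken from the answer: zip each item with its running count and group on the count with operator.itemgetter, so the items travel with their keys instead of being pulled from a second iterator inside each group. The name cgroup_zip is hypothetical.

```python
from itertools import accumulate, chain, groupby, islice, repeat, tee
from operator import itemgetter


def cgroup_zip(source):
    """Hypothetical variant of cgroup: pair each item with its group number."""
    it, it2 = tee(source)
    # Running count of non-None items, one value per source item.
    counts = accumulate(x is not None for x in it)
    # Group (count, item) pairs on the count, then keep only the items.
    for _, pairs in groupby(zip(counts, it2), key=itemgetter(0)):
        yield [item for _, item in pairs]


finite = list(cgroup_zip([1, None, 1, 1, None, None, 1]))
print(finite)  # [[1, None], [1], [1, None, None], [1]]

stream = chain.from_iterable(repeat([1, 1, None]))
first_four = list(islice(cgroup_zip(stream), 4))
print(first_four)  # [[1], [1, None], [1], [1, None]]
```

Because zip advances both tee copies in lockstep, this version also works on infinite sources.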
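
For comparison, the same grouping can also be written as a plain generator with no groupby and no key function at all. This is an alternative sketch, not the answer's method; group_continuations is a hypothetical name. It starts a new group at every non-None item and appends each None to the group in progress.

```python
from itertools import chain, islice, repeat


def group_continuations(items):
    """Hypothetical helper: start a new group at every non-None item."""
    group = []
    for item in items:
        if item is not None and group:
            yield group
            group = []  # rebind, so the list already yielded is not mutated
        group.append(item)  # a leading None simply starts the first group
    if group:  # flush the last group when a finite source runs out
        yield group


finite = list(group_continuations([1, None, 1, 1, None, None, 1]))
print(finite)  # [[1, None], [1], [1, None, None], [1]]

stream = chain.from_iterable(repeat([1, 1, None]))
first_four = list(islice(group_continuations(stream), 4))
print(first_four)  # [[1], [1, None], [1], [1, None]]
```

The state (the group being built) is local to one generator call, so separate calls cannot interfere with each other the way the attributes on f can.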