python generator python-itertools

Flattening nested generator expressions


I'm trying to flatten a nested generator of generators but I'm getting an unexpected result:

>>> g = ((3*i + j for j in range(3)) for i in range(3))
>>> list(itertools.chain(*g))
[6, 7, 8, 6, 7, 8, 6, 7, 8]

I expected the result to look like this:

[0, 1, 2, 3, 4, 5, 6, 7, 8]

I think I'm getting the unexpected result because the inner generators are not evaluated until the outer generator has already been exhausted, by which point i is 2. I can work around this by forcing evaluation of the inner generators, using a list comprehension instead of a generator expression:

>>> g = ([3*i + j for j in range(3)] for i in range(3))
>>> list(itertools.chain(*g))
[0, 1, 2, 3, 4, 5, 6, 7, 8]
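
The behaviour can be reproduced without itertools. This is a minimal sketch, assuming the cause is what the paragraph above suspects: argument unpacking exhausts the outer generator before any inner generator runs, and each inner generator looks up i only when it is iterated. The names inner, late and lazy are made up for illustration.

```python
# Argument unpacking (*g) runs the outer generator to completion first,
# much like tuple(g) does here.
g = ((3*i + j for j in range(3)) for i in range(3))
inner = tuple(g)  # three inner generators, none started yet; i is now 2

# Every inner generator reads i from the same enclosing scope,
# so they all see its final value.
late = [list(gen) for gen in inner]
# late == [[6, 7, 8], [6, 7, 8], [6, 7, 8]]

# Consuming each inner generator before the outer one advances avoids this.
g = ((3*i + j for j in range(3)) for i in range(3))
lazy = [list(gen) for gen in g]
# lazy == [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```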

Ideally, I would like a solution that's completely lazy and doesn't force evaluation of the inner nested elements until they're used.

Is there a way to flatten nested generator expressions of arbitrary depth (maybe using something other than itertools.chain)?

Edit:

No, my question is not a duplicate of Variable Scope In Generators In Classes. I honestly can't tell how these two questions are related at all. Maybe the moderator could explain why he thinks this is a duplicate.

Also, both answers to my question are correct in that they can be used to write a function that flattens nested generators correctly.

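import collections.abc
import itertools

def flattened1(iterable):
    # tee lets us peek at the first element without losing it from iter2.
    iter1, iter2 = itertools.tee(iterable)
    # collections.Iterable was removed in Python 3.10; use collections.abc.
    # The None default keeps an empty input from raising StopIteration.
    if isinstance(next(iter1, None), collections.abc.Iterable):
        return flattened1(x for y in iter2 for x in y)
    else:
        return iter2

def flattened2(iterable):
    iter1, iter2 = itertools.tee(iterable)
    if isinstance(next(iter1, None), collections.abc.Iterable):
        return flattened2(itertools.chain.from_iterable(iter2))
    else:
        return iter2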
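
For comparison, a recursive generator function is another way to handle arbitrary depth without tee. This is only a sketch and is not taken from either answer: the name flatten_deep is made up for illustration, and treating str and bytes as leaves is an assumption added so that strings do not recurse forever.

```python
from collections.abc import Iterable

def flatten_deep(items):
    """Lazily yield the leaves of an arbitrarily nested iterable."""
    for item in items:
        # Strings are iterable but are treated as leaves here (assumption).
        if isinstance(item, Iterable) and not isinstance(item, (str, bytes)):
            yield from flatten_deep(item)
        else:
            yield item

g = ((3*i + j for j in range(3)) for i in range(3))
result = list(flatten_deep(g))
# result == [0, 1, 2, 3, 4, 5, 6, 7, 8]

ragged = list(flatten_deep([1, [2, [3, "ab"]], (4,)]))
# ragged == [1, 2, 3, 'ab', 4]
```

Unlike the two functions above, which decide how deep to go by inspecting only the first element, this version checks every element, so it also copes with unevenly nested input.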

As far as I can tell with timeit, the two perform almost identically.

>>> timeit(test1, setup1, number=1000000)
18.173431718023494
>>> timeit(test2, setup2, number=1000000)
17.854709611972794

I'm not sure which one is better from a style standpoint either, since x for y in iter2 for x in y is a bit of a brain twister, but arguably more elegant than itertools.chain.from_iterable(iter2). Input is appreciated.

Regrettably, I was only able to mark one of the two equally good answers correct.


Solution

  • Instead of using chain(*g), you can use chain.from_iterable:

    >>> g = ((3*i + j for j in range(3)) for i in range(3))
    >>> list(itertools.chain(*g))
    [6, 7, 8, 6, 7, 8, 6, 7, 8]
    >>> g = ((3*i + j for j in range(3)) for i in range(3))
    >>> list(itertools.chain.from_iterable(g))
    [0, 1, 2, 3, 4, 5, 6, 7, 8]
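
    The difference lies in when the outer generator is consumed. chain(*g) has to unpack all of g to build its argument list, whereas chain.from_iterable(g) pulls one inner iterable at a time. A rough pure-Python equivalent, along the lines of the one given in the itertools documentation, might look like this (from_iterable below is a stand-in function for illustration, not the real classmethod):

```python
def from_iterable(iterables):
    # Each inner iterable is fully consumed before the outer one advances,
    # so i still holds the matching value while the inner generator runs.
    for it in iterables:
        for element in it:
            yield element

g = ((3*i + j for j in range(3)) for i in range(3))
flat = list(from_iterable(g))
# flat == [0, 1, 2, 3, 4, 5, 6, 7, 8]
```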