python, python-3.x, generator, python-itertools

Generator and list return different results


I'm trying to take in a generator whose elements are lists/tuples/iterables of equal length, and return a separate generator for each index of the elements.

When I hard-code the indexes, as in split_feat2 below, it works as expected. However, when I build the generators in a list comprehension (or append them to a list in a loop) and return that list, I get incorrect results.

I checked my logic by returning a list of lists instead of a list of generators (replacing the () with [] in the comprehension), and that yielded correct results, so I have no idea where the issue is.

Any insight into why it's acting the way it is would be greatly appreciated.

from itertools import tee

def split_feat2(gen):
    G = tee(gen, 2)
    return [(e[0] for e in G[0]), (e[1] for e in G[1])]

def split_feat(gen, n):
    G = tee(gen, n)
    return [(e[n] for e in g) for n, g in enumerate(G)]

def split_featlist(gen, n):
    G = tee(gen, n)
    return [[e[n] for e in g] for n, g in enumerate(G)]

test = lambda:((i^2,j+i) for i, j in enumerate(range(10)))

print("This is what I want")
t = split_feat2(test())
print(list(t[0]))
print(list(t[1]))
print(t)

print("\nBut I get this output")
t = split_feat(test(), 2)
print(list(t[0]))
print(list(t[1]))
print(t)

print("\nWhen I want this output but from generators instead of lists")
t = split_featlist(test(), 2)
print(list(t[0]))
print(list(t[1]))
print(t)

The code above produces this output:

This is what I want
[2, 3, 0, 1, 6, 7, 4, 5, 10, 11]
[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
[<generator object split_feat2.<locals>.<genexpr> at 0x00000219C794F7D8>, <generator object split_feat2.<locals>.<genexpr> at 0x00000219C794F200>]

But I get this output
[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
[<generator object split_feat.<locals>.<listcomp>.<genexpr> at 0x00000219C791DB48>, <generator object split_feat.<locals>.<listcomp>.<genexpr> at 0x00000219C794F150>]

When I want this output but from generators instead of lists
[2, 3, 0, 1, 6, 7, 4, 5, 10, 11]
[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
[[2, 3, 0, 1, 6, 7, 4, 5, 10, 11], [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]]

Solution

  • The problem is that a generator expression looks up n each time it produces an item, not when it is created, and n has changed by the time you actually consume the generators. When the function returns the list of generators, the comprehension's loop variable n holds its final value, which is the function parameter n minus 1. So in your example both generators use the same index: 1. To see this in isolation, look at this simple example:

    >>> list_of_list = [[0, 1]]*20
    >>> index = 1
    >>> gen = (item[index] for item in list_of_list)
    >>> print(next(gen))
    1
    >>> index = 0
    >>> print(next(gen))  # changing index "changed the generator"
    0
    

    In your case it was the comprehension's loop that changed n (rather than a manual assignment as in my example), but the effect is the same: by the time the generators run, n has one final value, and every generator created in the loop uses it.
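
    The same effect can be reproduced without tee. The following is only a minimal sketch (the names rows and gens are made up for illustration): two generator expressions are built inside a list comprehension, and both end up reading the last value of the loop variable.

```python
rows = [(10, 20), (30, 40)]

# Each generator expression looks up `n` only when it is consumed,
# which happens after the comprehension has finished and n == 1.
gens = [(row[n] for row in rows) for n in range(2)]

first = list(gens[0])
second = list(gens[1])
print(first)   # [20, 40]
print(second)  # [20, 40]
```

    Both lists contain the second column, even though the first generator was created while n was 0.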

    Solution

    You need to "fix" the current value of n in some way for each iteration. One possibility is map with operator.itemgetter:

    from itertools import tee
    from operator import itemgetter

    def split_feat(gen, n):
        G = tee(gen, n)
        return [map(itemgetter(n), g) for n, g in enumerate(G)]
    

    The itemgetter is created immediately with the "current" n value, so the result will be as expected.
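
    As a quick check, here is a self-contained sketch of that approach run against the same values as the question's test generator. The loop variable is renamed to i here purely to avoid shadowing the parameter n, and the name pairs is an assumption for the example.

```python
from itertools import tee
from operator import itemgetter

def split_feat(gen, n):
    G = tee(gen, n)
    # map() evaluates its arguments right away, so itemgetter(i) is
    # built with the current index instead of being looked up later.
    return [map(itemgetter(i), g) for i, g in enumerate(G)]

# Same values as `test` in the question (i ^ 2 is XOR, j + i == 2 * i).
pairs = ((i ^ 2, 2 * i) for i in range(10))
first, second = split_feat(pairs, 2)

first_values = list(first)
second_values = list(second)
print(first_values)   # [2, 3, 0, 1, 6, 7, 4, 5, 10, 11]
print(second_values)  # [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
```

    The results are still lazy iterators (map objects in Python 3), so nothing is read from the source until one of them is consumed.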

    That's not the only way to achieve the desired result. You could also use a function that creates the generator. Each call binds the current n to the function's own parameter, so the generator "remembers" that value (like a closure) and works as you expect:

    def split_feat(gen, n):
        G = tee(gen, n)
        def create_generator(it, n):
            return (item[n] for item in it)
        return [create_generator(g, n) for n, g in enumerate(G)]
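
    A closely related variant, shown here only as a sketch, is to write the helper as a generator function with yield instead of returning a generator expression. The names column and pairs are hypothetical; the point is that idx is a parameter, so every call gets its own binding.

```python
from itertools import tee

def column(it, idx):
    # idx is bound when column() is called, not when items are pulled
    for item in it:
        yield item[idx]

def split_feat(gen, n):
    return [column(g, i) for i, g in enumerate(tee(gen, n))]

# Same values as `test` in the question (i ^ 2 is XOR, j + i == 2 * i).
pairs = ((i ^ 2, 2 * i) for i in range(10))
first, second = split_feat(pairs, 2)

first_values = list(first)
second_values = list(second)
print(first_values)   # [2, 3, 0, 1, 6, 7, 4, 5, 10, 11]
print(second_values)  # [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
```

    Keep in mind that tee has to buffer every item that one of its iterators has consumed and the others have not, so exhausting the columns one after another holds the whole input in memory.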