python machine-learning generator conv-neural-network imperative-programming

How to modify last element of a generator in python?


I have a generator and I want to modify the last element of the generator. I want to replace the last element with another element. I know how to retrieve the last element, but not how to modify it.

What would be the best way to approach this?

For more context, this is what I want to do:

for child in alexnet.children():
    for children_of_child in child.children():
        print(children_of_child)

My generator is child.children(), and for the second child it yields these children:

Dropout(p=0.5)
Linear(in_features=9216, out_features=4096, bias=True)
ReLU(inplace)
Dropout(p=0.5)
Linear(in_features=4096, out_features=4096, bias=True)
ReLU(inplace)
Linear(in_features=4096, out_features=1000, bias=True)

I want to replace the last layer Linear(in_features=4096, out_features=1000, bias=True) with my own regression net.


Solution

  • Since you're working with a reasonably small list (even ResNet-152 is "reasonably small" in RAM terms), I'd make this easy to understand and maintain. There is no "obvious" way to detect that you're one step short of exhausting a generator.

    1. Deplete the current generator, making a list of its output.
    2. Replace the last element as desired.
    3. Wrap a new generator around this altered list.

    The "nice" (?) way to do this is to write a wrapper generator with a one-element look-ahead into the original: at step N, the wrapper already holds element N and tries to fetch element N+1 from the "real" generator (your posted code). If that element exists, the wrapper yields element N unchanged. If the real generator is exhausted, element N was the last one, so the wrapper yields your replacement instead.

    EXAMPLE:

    To keep this simple, I've used range in place of your original generator; the code follows the three list-based steps above.

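    def new_tail():
        # Step 1: deplete the source generator into a list.
        my_list = list(range(6))
        # Step 2: replace the last element.
        my_list[-1] = "new last element"
        # Step 3: yield the altered list as a new generator.
        yield from my_list

    for item in new_tail():
        print(item)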
    

    Output:

    0
    1
    2
    3
    4
    new last element
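
The look-ahead wrapper described earlier could be sketched roughly as follows. The function name replace_last and its parameters are my own, purely for illustration, and range again stands in for your real generator.

```python
def replace_last(iterable, replacement):
    """Yield the items of `iterable`, substituting `replacement` for the final one."""
    iterator = iter(iterable)
    try:
        held = next(iterator)  # element N, held back until we know whether it is the last
    except StopIteration:
        return  # empty input: there is no last element to replace
    for upcoming in iterator:
        # Element N+1 exists, so the held element is not the last one.
        yield held
        held = upcoming
    # The source is exhausted, so the held element was the last: swap it out.
    yield replacement


print(list(replace_last(range(6), "new last element")))
# [0, 1, 2, 3, 4, 'new last element']
```

This never holds more than one element in memory, which only matters for long or unbounded streams. Note that it changes what the iteration yields; it does not modify the object that produced the original generator.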
    

    Does that help?