
I have a generator, and I want to replace its last element with another element. I know how to retrieve the last element, but not how to modify it.

What would be the best way to approach this?

For more context, this is what I want to do:

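import torchvision.models as models

alexnet = models.alexnet()  # torchvision's AlexNet (assumed from the layers listed below)
for child in alexnet.children():
    for children_of_child in child.children():
        print(children_of_child)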

The generator here is child.children() (children_of_child is each module it yields), and for the second child, the modules it yields are:

Dropout(p=0.5)
Linear(in_features=9216, out_features=4096, bias=True)
ReLU(inplace)
Dropout(p=0.5)
Linear(in_features=4096, out_features=4096, bias=True)
ReLU(inplace)
Linear(in_features=4096, out_features=1000, bias=True)

I want to replace the last layer, Linear(in_features=4096, out_features=1000, bias=True), with my own regression net.

aa1

2 Answers


Since you're working with a reasonably small list (even ResNet-152 is "reasonably small" in RAM terms), I'd make this easy to understand and maintain. There is no obvious way to detect that you're one step short of exhausting a generator, so the simplest approach is:

  1. Deplete the current generator, making a list of its output.
  2. Replace the last element as desired.
  3. Wrap a new generator around this altered list.

The "nicer" (?) way is to write a wrapper generator that looks one element ahead in the original: at each step N, your wrapper already holds element N. It then grabs element N+1 from the "real" generator (the one in your posted code). If that element exists, yield element N unchanged. If the generator is exhausted, element N was the last one, so yield your replacement instead.
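Here's a minimal sketch of that look-ahead wrapper. The names replace_last_lookahead and _MISSING are my own, and yielding nothing for an empty source is an assumption; change that branch if you want something else. Unlike the list approach, it never holds more than two elements at once.

```python
_MISSING = object()  # unique marker meaning "the source has no more elements"


def replace_last_lookahead(source, replacement):
    """Yield everything from `source`, but swap the final element for `replacement`.

    Holds element N while peeking at element N+1; only when the peek fails
    do we know that N was the last element.
    """
    it = iter(source)
    current = next(it, _MISSING)           # element N
    if current is _MISSING:                # empty source: nothing to yield (assumed)
        return
    while True:
        upcoming = next(it, _MISSING)      # look ahead to element N+1
        if upcoming is _MISSING:           # source exhausted, so N was the last
            yield replacement
            return
        yield current                      # N+1 exists, so N is not the last
        current = upcoming


print(list(replace_last_lookahead(range(6), "new last element")))
# [0, 1, 2, 3, 4, 'new last element']
```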

EXAMPLE:

To keep this simple, I've used range in place of your original generator.

def new_tail():
    my_list = list(range(6))
    my_list[-1] = "new last element"
    for elem in my_list:
        yield elem

for item in new_tail():
    print(item)

Output:

0
1
2
3
4
new last element

Does that help?

Prune
  • I have completed the steps, and for step 3 I have created a generator as follows: myGen = (n for n in myList). However, now that I try to assign this generator back to 'child', it gives me an error since I cannot 'call' this generator. How can I assign it back? – aa1 Jul 10 '18 at 18:07
  • Let me see about reproducing this. Is Python 3.4.5 okay? – Prune Jul 10 '18 at 20:50
  • Yes, I use 3.5 or 3.6, but I think 3.4.5 should be fine too. Thanks! – aa1 Jul 10 '18 at 23:24

The way to do this is to iterate one step ahead, keeping track of the previous value as you go. For each value, yield the previous one. When you get to the end, instead of yielding the last previous value, yield the replacement value:

def new_tail(it, tail):
    sentinel = prev = object()
    for value in it:
        if prev is not sentinel:
            yield prev
        prev = value
    yield tail

Or you can treat the first element specially instead of using a sentinel:

def new_tail(it, tail):
    it = iter(it)
    prev = next(it)
    for value in it:
        yield prev
        prev = value
    yield tail

You may want to think about what should happen with a completely empty iterator. I'm not sure whether you want to yield nothing, yield the replacement value, or raise an exception. The first version yields the replacement value. The second… well, it should raise an exception, but not a clean one: before Python 3.7, the StopIteration from next(it) just ends the generator, so it yields nothing (3.6 also issues a DeprecationWarning), and as of 3.7 (PEP 479) that StopIteration is converted into a RuntimeError. Neither is probably the behavior you want.

Either way, you can call next with a sentinel default value, or catch StopIteration around the next call. Then it's easy to do whichever of the three you want.
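For example, here's a sketch of the second version rewritten to use next with a sentinel default. The on_empty parameter and its three string values are hypothetical, just a way to show all three choices in one place; a try/except StopIteration around next(it) would work just as well.

```python
_SENTINEL = object()  # private marker; can't collide with real elements


def new_tail(it, tail, on_empty="tail"):
    """Yield all of `it` except its last element, then yield `tail`.

    `on_empty` (hypothetical, for this sketch) picks what happens when `it`
    is empty: "nothing", "tail", or "raise".
    """
    it = iter(it)
    prev = next(it, _SENTINEL)  # default value instead of a StopIteration
    if prev is _SENTINEL:
        if on_empty == "tail":
            yield tail
        elif on_empty == "raise":
            raise ValueError("new_tail() got an empty iterable")
        elif on_empty != "nothing":
            raise ValueError(f"unknown on_empty policy: {on_empty!r}")
        return
    for value in it:
        yield prev       # prev is not the last element, since value follows it
        prev = value
    yield tail           # replaces the final prev


print(list(new_tail([1, 2, 3], "x")))                # [1, 2, 'x']
print(list(new_tail([], "x")))                       # ['x']
print(list(new_tail([], "x", on_empty="nothing")))   # []
```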


But you can make this simpler if you think about it a little more abstractly: if you had all the adjacent pairs of elements, the first element of each pair gives you every element but the last. So, using the pairwise recipe from the itertools docs (available as itertools.pairwise in Python 3.10+):

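from itertools import pairwise  # Python 3.10+; on older versions, copy the recipe from the itertools docs

def new_tail(it, tail):
    # The first element of each adjacent pair is everything except the last element.
    for x, _ in pairwise(it):
        yield x
    yield tail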
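On Python older than 3.10, itertools has no pairwise function, so you need the recipe itself. Here's a self-contained sketch with the recipe as it appeared in the pre-3.10 itertools docs, built on tee and zip:

```python
from itertools import tee


def pairwise(iterable):
    """s -> (s0, s1), (s1, s2), (s2, s3), ...  (recipe from the itertools docs)"""
    a, b = tee(iterable)
    next(b, None)   # advance the second copy by one element
    return zip(a, b)


def new_tail(it, tail):
    # The first element of each adjacent pair covers everything but the last item.
    for x, _ in pairwise(it):
        yield x
    yield tail


print(list(new_tail("abcd", "Z")))  # ['a', 'b', 'c', 'Z']
```

Like the sentinel version, this yields just the replacement for an empty or one-element input, because pairwise produces no pairs in either case.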

Or, if you prefer, you can even make it a single expression using itertools.chain and operator.itemgetter, although this is probably a bit silly:

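from itertools import chain, pairwise  # pairwise needs Python 3.10+
from operator import itemgetter
def new_tail(it, tail):
    return chain(map(itemgetter(0), pairwise(it)), (tail,))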
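To tie this back to the question, here's a sketch that runs the single-expression version over strings standing in for the classifier's layers. The strings and the "MyRegressionNet" name are placeholders, not real PyTorch modules, and the pairwise recipe is included so it runs on any Python 3. Note that this only changes what the new iterator yields; it doesn't modify the model the layers came from.

```python
from itertools import chain, tee
from operator import itemgetter


def pairwise(iterable):
    # Pre-3.10 recipe from the itertools docs; on 3.10+ use itertools.pairwise.
    a, b = tee(iterable)
    next(b, None)
    return zip(a, b)


def new_tail(it, tail):
    # First elements of all adjacent pairs (everything but the last), then `tail`.
    return chain(map(itemgetter(0), pairwise(it)), (tail,))


# Hypothetical stand-ins for the classifier layers listed in the question.
classifier_layers = (
    layer for layer in [
        "Dropout(p=0.5)",
        "Linear(in_features=9216, out_features=4096, bias=True)",
        "ReLU(inplace)",
        "Dropout(p=0.5)",
        "Linear(in_features=4096, out_features=4096, bias=True)",
        "ReLU(inplace)",
        "Linear(in_features=4096, out_features=1000, bias=True)",
    ]
)

for layer in new_tail(classifier_layers, "MyRegressionNet"):
    print(layer)
```

This prints the first six layer strings unchanged, followed by MyRegressionNet in place of the final Linear layer.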
abarnert