
Say we wish to process an iterator in chunks.
The logic for each chunk depends on previously calculated chunks, so groupby() does not help.

Our friend in this case is itertools.takewhile():

while True:
    chunk = itertools.takewhile(getNewChunkLogic(), myIterator)
    process(chunk)

The problem is that takewhile() has to consume the first element that fails the predicate in order to know where the chunk ends, thus 'eating' the first element of the next chunk.
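To make the issue concrete (a minimal sketch with a trivial predicate, not my actual chunk logic):

```python
from itertools import takewhile

it = iter(range(10))
first = list(takewhile(lambda x: x < 5, it))  # [0, 1, 2, 3, 4]
rest = list(it)                               # [6, 7, 8, 9] -- the 5 was consumed and lost
```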

There are various solutions to that, including wrapping the iterator or pushing the element back à la C's ungetc().
My question is: is there an elegant solution?

Mazdak
Paul Oyster

4 Answers


takewhile() indeed needs to look at the next element to determine when to toggle behaviour.

You could use a wrapper that tracks the last seen element, and that can be 'reset' to back up one element:

_sentinel = object()

class OneStepBuffered(object):
    def __init__(self, it):
        self._it = iter(it)
        self._last = _sentinel
        self._next = _sentinel
    def __iter__(self):
        return self
    def __next__(self):
        if self._next is not _sentinel:
            next_val, self._next = self._next, _sentinel
            return next_val
        try:
            self._last = next(self._it)
            return self._last
        except StopIteration:
            self._last = self._next = _sentinel
            raise
    next = __next__  # Python 2 compatibility
    def step_back(self):
        if self._last is _sentinel:
            raise ValueError("Can't back up a step")
        self._next, self._last = self._last, _sentinel

Wrap your iterator in this one before using it with takewhile():

myIterator = OneStepBuffered(myIterator)
while True:
    chunk = itertools.takewhile(getNewChunkLogic(), myIterator)
    process(chunk)
    myIterator.step_back()

Demo:

>>> from itertools import takewhile
>>> test_list = range(10)
>>> iterator = OneStepBuffered(test_list)
>>> list(takewhile(lambda i: i < 5, iterator))
[0, 1, 2, 3, 4]
>>> iterator.step_back()
>>> list(iterator)
[5, 6, 7, 8, 9]
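One detail the chunking loop glosses over is termination: once the source iterator is exhausted, step_back() raises ValueError, and that exception can double as the stop signal. A sketch of a complete run, with a hypothetical running-sum predicate standing in for getNewChunkLogic() (the wrapper is condensed from the class above):

```python
import itertools

_sentinel = object()

class OneStepBuffered:
    """Condensed version of the one-step-back wrapper above."""
    def __init__(self, it):
        self._it = iter(it)
        self._last = self._next = _sentinel
    def __iter__(self):
        return self
    def __next__(self):
        if self._next is not _sentinel:
            val, self._next = self._next, _sentinel
            return val
        try:
            self._last = next(self._it)
            return self._last
        except StopIteration:
            self._last = _sentinel
            raise
    def step_back(self):
        if self._last is _sentinel:
            raise ValueError("Can't back up a step")
        self._next, self._last = self._last, _sentinel

def sum_under(limit):
    # hypothetical chunk logic: accept elements while the running sum stays under limit
    total = 0
    def pred(x):
        nonlocal total
        total += x
        return total < limit
    return pred

it = OneStepBuffered([3, 4, 5, 2, 8, 1, 1])
chunks = []
while True:
    chunk = list(itertools.takewhile(sum_under(10), it))
    if not chunk:
        break
    chunks.append(chunk)
    try:
        it.step_back()   # un-eat the element that ended this chunk
    except ValueError:
        break            # source exhausted: nothing was left to push back
print(chunks)            # [[3, 4], [5, 2], [8, 1], [1]]
```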
Martijn Pieters
  • @KarolyHorvath: It depends on how your original iterator was coded if performance is killed. If everything else was coded in C, then yes, this adds a step back into the Python interpreter and that can affect performance. The alternative is to re-tool your algorithm to not rely on `takewhile()`. Without details as to what you are doing that's not something I can help with at this point. – Martijn Pieters Jun 03 '15 at 09:52
  • @MartijnPieters: yes, that's what I meant. I'm not doing anything, the OP is someone else ;) As for the `ValueError`: it makes perfect sense from your `OneStepBuffered`'s viewpoint, but for OP's task, that's looks like a bug to me. – Karoly Horvath Jun 03 '15 at 09:57
  • well, assuming that it's a valid scenario.... note: you can catch the exception if that's the case – Karoly Horvath Jun 03 '15 at 10:03
  • assume an endless, but fast, iterator. – Paul Oyster Jun 03 '15 at 10:03

I had the same problem. You might wish to use itertools.tee or itertools.pairwise (new in Python 3.10) to deal with this, but I didn't think those solutions were very elegant.

The best solution I found is to just rewrite takewhile, based heavily on the documentation:

def takewhile_inclusive(predicate, it):
    for x in it:
        if predicate(x):
            yield x
        else:
            yield x
            break

In your loop you can elegantly treat the final element separately using unpacking:

*chunk, lastPiece = takewhile_inclusive(getNewChunkLogic(), myIterator)

You can then chain the last piece:

lastPiece = None
while True:
    *chunk, lastPiece = takewhile_inclusive(getNewChunkLogic(), myIterator)
    if lastPiece is not None:
        myIterator = itertools.chain([lastPiece], myIterator)
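Put together, a runnable sketch of this approach (with a hypothetical running-sum predicate standing in for getNewChunkLogic(), and the empty and single-element cases handled, since the bare unpacking crashes on an exhausted iterator as the comments note):

```python
from itertools import chain

def takewhile_inclusive(predicate, it):
    for x in it:
        if predicate(x):
            yield x
        else:
            yield x
            break

def sum_under(limit):
    # hypothetical chunk logic: keep taking while the running sum stays under limit
    total = 0
    def pred(x):
        nonlocal total
        total += x
        return total < limit
    return pred

data = iter([3, 4, 5, 2, 8, 1, 1])
chunks = []
while True:
    piece = list(takewhile_inclusive(sum_under(10), data))
    if not piece:
        break                       # source exhausted
    *chunk, last = piece
    if chunk:
        chunks.append(chunk)
        data = chain([last], data)  # push the failing element back for the next chunk
    else:
        chunks.append([last])       # lone element: it ends the data or fails on its
                                    # own; pushing it back would loop forever
print(chunks)                       # [[3, 4], [5, 2], [8, 1], [1]]
```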
  
Duncan
  • You could shorten that to `for x in it:` `yield x` `if not predicate(x):` `break`. – Kelly Bundy Apr 08 '23 at 18:55
  • If the input iterator is "empty", your `*chunk,lastPiece = ...` crashes. – Kelly Bundy Apr 08 '23 at 19:01
  • Your initial `lastPiece = None` is pointless, that value is never used. And I can't make sense of your `if lastPiece is not None`. Why are you removing `None` values from the input iterator? And why only if they're right after a chunk? – Kelly Bundy Apr 08 '23 at 19:04
  • Nested chains can become slow, as all elements get moved through the whole stack of chain iterators. – Kelly Bundy Apr 08 '23 at 19:07
  • @KellyBund the problem OP had was takewhile "'eating' the first element for the next chunk". So the `if lastPiece is not None` part of the code is pushing the last piece back in the front of the iterator, so it isn't 'eaten'. I like your suggestion to shorten it, I just wrote it in this verbose way to mirror the linked documentation. – Duncan Apr 09 '23 at 20:14
  • But why are you looking for `None`? Why not `if lastPiece != 42`? What makes `None` special here? – Kelly Bundy Apr 09 '23 at 20:20

Given that the callable GetNewChunkLogic() returns a predicate that reports True throughout the first chunk and False afterward, the following snippet:

  1. solves the 'additional next step' problem of takewhile, and
  2. is elegant, because you don't have to implement the back-one-step logic.

from itertools import tee, filterfalse

def partition(pred, iterable):
    'Use a predicate to partition entries into true entries and false entries'
    # partition(is_odd, range(10)) -->  1 3 5 7 9 and 0 2 4 6 8
    t1, t2 = tee(iterable)
    return filter(pred, t1), filterfalse(pred, t2)

while True:
    head, tail = partition(GetNewChunkLogic(), myIterator)
    process(head)
    myIterator = tail

However, the most elegant way is to modify your GetNewChunkLogic into a generator and remove the while loop.
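A sketch of that generator approach, assuming for concreteness a hypothetical chunk rule of "running sum stays under a limit": the chunking logic walks the iterator itself and yields complete chunks, so no element is ever eaten:

```python
def chunked(iterable, limit=10):
    # hypothetical rule: each chunk's sum stays under `limit`
    # (an oversized single element forms a chunk of its own)
    chunk, total = [], 0
    for x in iterable:
        if chunk and total + x >= limit:
            yield chunk
            chunk, total = [], 0
        chunk.append(x)
        total += x
    if chunk:
        yield chunk

print(list(chunked([3, 4, 5, 2, 8, 1, 1])))  # [[3, 4], [5, 2], [8, 1], [1]]
```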

WeiChing 林煒清
  • This is very bad. Even if the current chunk logic does keep reporting only `False` afterwards, never `True` again (which isn't guaranteed), this is very problematic. Every `head` always runs all the way to the end of the `myIterator`, which is slow and loads it all into memory. And filtering values through the whole stack of `filterfalse` iterators that you build is slow, too. – Kelly Bundy Apr 08 '23 at 19:54

Here's another way you can do it. Yield a value (sentinel) when the predicate fails, but before yielding the value itself. Then group by values which aren't the sentinel.

Here, group_by_predicate requires a function that returns a predicate (predicate_gen). A fresh predicate is created every time the current one fails:

from itertools import groupby


def group_by_predicate(predicate_gen, _iter):
    sentinel = object()
    def _group_with_sentinel():
        pred = predicate_gen()
        for n in _iter:
            while not pred(n):
                yield sentinel
                pred = predicate_gen()
            yield n
            
    g = _group_with_sentinel()
    for k, g in groupby(g, lambda s: s!=sentinel):
        if k:
            yield g

This can then be used like:

def less_than_gen(maxn):
    """Return a predicate that returns true while the sum of inputs is < maxn"""
    def pred(i):
        pred.count += i
        return pred.count < maxn
    pred.count = 0
    return pred

data = iter(list(range(9)) * 3)

for g in group_by_predicate(lambda: less_than_gen(15), data):
    print(list(g))

Which outputs groups of numbers, each summing to less than 15:

[0, 1, 2, 3, 4]
[5, 6]
[7]
[8, 0, 1, 2, 3]
[4, 5]
[6, 7]
[8, 0, 1, 2, 3]
[4, 5]
[6, 7]
[8]
MattM