
"Pythonic" way to return elements from an iterable as long as a condition based on previous element is true


I am working on some code that needs to keep taking elements from an iterable as long as a condition based on (or related to) the previous element holds. For example, let's say I have a list of numbers:

lst = [0.1, 0.4, 0.2, 0.8, 0.7, 1.1, 2.2, 4.1, 4.9, 5.2, 4.3, 3.2]

And let's use a simple condition: a number must not differ from the previous number by more than 1. So the expected output would be

[0.1, 0.4, 0.2, 0.8, 0.7, 1.1]

Normally, itertools.takewhile would be a good choice, but here it is awkward because the first element has no previous element to compare against. The following code returns an empty list because, for the first element, lst.index(x) - 1 is -1, so the lambda compares it with the last element of the list.

from itertools import takewhile
res1 = list(takewhile(lambda x: abs(lst[lst.index(x)-1] - x) <= 1., lst))
print(res1)
# []
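To see concretely why `res1` is empty, here is the lookup the lambda performs for the first element (a small check for illustration, not a fix):

```python
lst = [0.1, 0.4, 0.2, 0.8, 0.7, 1.1, 2.2, 4.1, 4.9, 5.2, 4.3, 3.2]

x = lst[0]                     # first element, 0.1
prev_index = lst.index(x) - 1  # 0 - 1 == -1, which Python treats as "last element"
prev = lst[prev_index]         # 3.2, which is not a real predecessor of 0.1
# The gap is about 3.1 > 1, so the predicate is already False on the first
# element and takewhile stops before yielding anything.
print(prev_index, prev, abs(prev - x) <= 1.)
```

Note also that `lst.index(x)` returns the position of the first matching value, so the lambda would look up the wrong predecessor if the list contained duplicates.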

I managed to write some "ugly" code to work around it:

res2 = []
for i, x in enumerate(lst):
    res2.append(x)
    # Make sure index is not out of range
    if i < len(lst) - 1:
        if not abs(lst[i+1] - x) <= 1.:
            break
print(res2)
# [0.1, 0.4, 0.2, 0.8, 0.7, 1.1]
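If you stay with a plain loop, one tidier variant (a sketch, not the only option) compares each element with the last element already collected. That removes the index arithmetic and the bounds check, and because it never indexes into `lst`, it also works on iterators:

```python
lst = [0.1, 0.4, 0.2, 0.8, 0.7, 1.1, 2.2, 4.1, 4.9, 5.2, 4.3, 3.2]

res2 = []
for x in lst:
    # res2[-1] is the previous element of lst, since elements are kept
    # consecutively; the first element is always kept (res2 is still empty).
    if res2 and abs(x - res2[-1]) > 1.:
        break
    res2.append(x)
print(res2)
```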

However, I feel like there should be a more "pythonic" way to write this. Any suggestions?


Solution

  • You can write your own version of takewhile where the predicate takes both the current and previous values:

    def my_takewhile(iterable, predicate):
        iterable = iter(iterable)
        try:
            previous = next(iterable)
        except StopIteration:
            # next(iterable) raises if the iterable is empty
            return
        yield previous
        for current in iterable:
            if not predicate(previous, current):
                break
            yield current
            previous = current
    

    Example:

    >>> list(my_takewhile(lst, lambda x, y: abs(x - y) <= 1))
    [0.1, 0.4, 0.2, 0.8, 0.7, 1.1]