
Filter out everything before a condition is met, keep all elements after


I was wondering if there is an easy solution to the following problem. I want to keep every element of a list from the point where a condition first becomes true. Here the condition is that a value is greater than 18: I want to remove everything before the first value greater than 18, but keep everything from that value on. Example:

Input:

p = [4,9,10,4,20,13,29,3,39]

Expected output:

p = [20,13,29,3,39]

I know that I can filter the entire list with

[x for x in p if x>18] 

But I want to stop filtering once the first value above 18 is found, and then include all remaining values regardless of whether they satisfy the condition. It seems like an easy problem, but I haven't found a solution yet.
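To make the difference concrete, here is a quick check on the example list above comparing the plain filter with the desired result. The filter tests every element independently, so it also drops 13 and 3, which come after the first value above 18:

```python
p = [4, 9, 10, 4, 20, 13, 29, 3, 39]

# Filtering checks every element on its own, so values after the
# first match are dropped too whenever they are <= 18.
filtered = [x for x in p if x > 18]
print(filtered)  # [20, 29, 39]

# The desired result keeps everything from the first value > 18 onward.
expected = [20, 13, 29, 3, 39]
print(filtered == expected)  # False
```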


Solution

  • You could use enumerate and list slicing in a generator expression, and take the first slice with next (which falls back to [] if no element exceeds 18):

    out = next((p[i:] for i, item in enumerate(p) if item > 18), [])
    

    Output:

    [20, 13, 29, 3, 39]
    
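next only advances the generator until the first matching position, so at most one slice is ever built. A minimal sketch wrapping the same one-liner in a reusable function; the name suffix_from_first and the pred parameter are hypothetical, added here for illustration:

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")

def suffix_from_first(p: List[T], pred: Callable[[T], bool]) -> List[T]:
    """Return p from the first element satisfying pred onward (hypothetical helper)."""
    # The generator yields a slice only at matching positions; next() takes
    # the first one and stops, or returns the default [] if nothing matches.
    return next((p[i:] for i, item in enumerate(p) if pred(item)), [])

p = [4, 9, 10, 4, 20, 13, 29, 3, 39]
print(suffix_from_first(p, lambda x: x > 18))          # [20, 13, 29, 3, 39]
print(suffix_from_first([1, 2, 3], lambda x: x > 18))  # []
```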

    Which approach is fastest depends on the data structure.

    The plots below compare the runtimes of the answers here for various lengths of p.

    If the original data is a list, then using a lazy iterator as proposed by @Kelly Bundy is the clear winner:

    [Plot: runtime of each approach vs. ~n/20, list input]
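The lazy-iterator idea works because the for loop and the [x, *it] unpacking share one iterator: the loop consumes elements only up to the first match, and the unpacking collects whatever remains. A small sketch illustrating this; the function name and the threshold parameter are hypothetical, and it otherwise mirrors lazy_iter from the benchmark code:

```python
from typing import Iterable, List

def drop_before_first_above(values: Iterable[int], threshold: int = 18) -> List[int]:
    it = iter(values)
    for x in it:  # advances the shared iterator one element at a time
        if x > threshold:
            # x was already consumed by the loop, so put it back in front;
            # *it then unpacks the remaining, not-yet-consumed elements.
            return [x, *it]
    return []  # iterator exhausted without a match

print(drop_before_first_above([4, 9, 10, 4, 20, 13, 29, 3, 39]))  # [20, 13, 29, 3, 39]
# Works on any iterable, including a one-shot generator:
print(drop_before_first_above(n * 5 for n in range(6)))  # [20, 25]
```

Unlike the slicing approaches, this needs no random access, so it also works on generators or other one-pass iterables.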

    But if the initial data is an ndarray, then the vectorized operations proposed by @richardec and @0x263A (for large arrays) are faster. In particular, numpy beats the list methods regardless of array size. For very large arrays, however, pandas starts to outperform numpy. I don't know why, and I (and I'm sure others) would appreciate it if anyone could explain it.

    [Plot: runtime of each approach vs. ~n/20, ndarray input]
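One caveat with the argmax-based approach (np_argmax in the benchmark code): when no element exceeds 18, the boolean mask is all False and argmax returns 0, so the whole array comes back instead of an empty one. A minimal guarded sketch, assuming the input is already a 1-D NumPy array; the function name is hypothetical:

```python
import numpy as np

def np_suffix_after_threshold(a: np.ndarray, threshold: float = 18) -> np.ndarray:
    mask = a > threshold
    if mask.size == 0:
        return a  # argmax raises on an empty array, so handle it first
    # argmax on a boolean array gives the index of the first True...
    first = int(mask.argmax())
    # ...but it is also 0 when there is no True at all, so check that position.
    if not mask[first]:
        return a[:0]  # empty array with the same dtype
    return a[first:]

a = np.array([4, 9, 10, 4, 20, 13, 29, 3, 39])
print(np_suffix_after_threshold(a).tolist())                   # [20, 13, 29, 3, 39]
print(np_suffix_after_threshold(np.array([1, 2, 3])).tolist())  # []
```

Checking mask[first] costs one element lookup, rather than a second full pass like mask.any() would.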

    Code used to generate the first plot:

    import perfplot
    import numpy as np
    import pandas as pd
    import random
    from itertools import dropwhile
    
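    def it_dropwhile(p):
        # skip leading elements while they are <= 18, keep the rest
        return list(dropwhile(lambda x: x <= 18, p))
    
    def walrus(p):
        # the flag flips to True at the first x > 18 and stays True afterwards
        exceeded = False
        return [x for x in p if (exceeded := exceeded or x > 18)]
    
    def explicit_loop(p):
        # find the first index with x > 18 and slice from there
        for i, x in enumerate(p):
            if x > 18:
                output = p[i:]
                break
        else:
            # loop finished without break: no element exceeds 18
            output = []
        return output
    
    def genexpr_next(p):
        return next((p[i:] for i, item in enumerate(p) if item > 18), [])
    
    def np_argmax(p):
        # argmax of the boolean mask is the first index with p > 18
        # (caveat: it is also 0 when no element exceeds 18)
        return p[(np.array(p) > 18).argmax():]
    
    def pd_idxmax(p):
        # idxmax of the boolean Series is the label of the first True
        s = pd.Series(p)
        return s[s.gt(18).idxmax():]
    
    def list_index(p):
        # find the first value > 18, then look up its position with index()
        for x in p:
            if x > 18:
                return p[p.index(x):]
        return []
    
    def lazy_iter(p):
        # the loop and the unpacking share one iterator, so *it yields
        # only the elements after x
        it = iter(p)
        for x in it:
            if x > 18:
                return [x, *it]
        return []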
    
    perfplot.show(
        setup=lambda n: random.choices(range(0, 15), k=10*n) + random.choices(range(-20,30), k=10*n),
        kernels=[it_dropwhile, walrus, explicit_loop, genexpr_next, np_argmax, pd_idxmax, list_index, lazy_iter],
        labels=['it_dropwhile','walrus','explicit_loop','genexpr_next','np_argmax','pd_idxmax', 'list_index', 'lazy_iter'],
        n_range=[2 ** k for k in range(18)],
        equality_check=np.allclose,
        xlabel='~n/20'
    )
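If perfplot isn't available, a rough single-size comparison can be made with the standard library's timeit. A minimal sketch, assuming a list input built like the benchmark setup above (with a fixed seed and n = 1000, both chosen arbitrarily here). It re-defines two of the kernels so it runs on its own, and the timings it prints will vary by machine:

```python
import random
import timeit

def genexpr_next(p):
    return next((p[i:] for i, item in enumerate(p) if item > 18), [])

def lazy_iter(p):
    it = iter(p)
    for x in it:
        if x > 18:
            return [x, *it]
    return []

random.seed(0)  # fixed seed so repeated runs use the same data
n = 1000
# Same shape as the perfplot setup: 10*n values from range(0, 15), then 10*n from range(-20, 30).
data = random.choices(range(0, 15), k=10 * n) + random.choices(range(-20, 30), k=10 * n)

# Sanity check: both kernels must agree before comparing their speed.
assert genexpr_next(data) == lazy_iter(data)

for func in (genexpr_next, lazy_iter):
    # 100 calls per measurement; keep the best of 5 repeats to reduce noise.
    best = min(timeit.repeat(lambda: func(data), number=100, repeat=5))
    print(f"{func.__name__}: {best / 100 * 1e6:.1f} us per call")
```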
    

    Code used to generate the second plot (note that I had to modify list_index, because numpy arrays don't have an index method):

    def list_index(p):
        for x in p:
            if x > 18:
                return p[np.where(p==x)[0][0]:]
        return []
    
    perfplot.show(
        setup=lambda n: np.hstack([np.random.randint(0,15,10*n), np.random.randint(-20,30,10*n)]),
        kernels=[it_dropwhile, walrus, explicit_loop, genexpr_next, np_argmax, pd_idxmax, list_index, lazy_iter],
        labels=['it_dropwhile','walrus','explicit_loop','genexpr_next','np_argmax','pd_idxmax', 'list_index', 'lazy_iter'],
        n_range=[2 ** k for k in range(18)],
        equality_check=np.allclose,
        xlabel='~n/20'
    )
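pd_idxmax shares the all-False caveat: idxmax on an all-False boolean Series returns the first index label, so the whole Series comes back. A guarded sketch that also slices by position with iloc, which avoids relying on how [] interprets an index label. The function name is hypothetical:

```python
import pandas as pd

def pd_suffix_after_threshold(s: pd.Series, threshold: float = 18) -> pd.Series:
    mask = s.gt(threshold).to_numpy()  # plain boolean NumPy array
    if not mask.any():
        return s.iloc[:0]  # empty Series with the same dtype
    # argmax gives the position of the first True; iloc slices by position,
    # so the result does not depend on the index labels.
    return s.iloc[int(mask.argmax()):]

s = pd.Series([4, 9, 10, 4, 20, 13, 29, 3, 39], index=list("abcdefghi"))
print(pd_suffix_after_threshold(s).tolist())                # [20, 13, 29, 3, 39]
print(pd_suffix_after_threshold(pd.Series([1, 2, 3])).empty)  # True
```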