Tags: python, list, performance, list-comprehension

set() inside list comprehension inefficient?


I'd like to compute the difference between two lists while maintaining ordering ("stable" removal):

Compute a list that contains all values of list a, in a's original order, with every value that also appears in list b removed.
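
For example, with some hypothetical sample values (not taken from the question), the expected behaviour would be:

a = [1, 2, 3, 2, 4]
b = [2, 5]
# expected result: [1, 3, 4]  -- every 2 removed, a's original order kept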

I found comments that this solution is inefficient because of using set in a list comprehension:

def diff1(a, b): 
    return [x for x in a if x not in set(b)]

And that this is more efficient:

def diff2(a, b):
    bset = set(b)
    return [x for x in a if x not in bset]

I tried to verify this empirically, and on the face of it it seems to be true. I tested it using IPython:

mysetup = '''
a = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
b = [3, 8, 9, 10]
'''

s1='def diff1(a, b): return [x for x in a if x not in set(b)]'

s2='''
def diff2(a, b):
    bset = set(b)
    return [x for x in a if x not in bset]
'''

Result:

In [80]: timeit.timeit(setup=mysetup, stmt=s1, number=100000000)
Out[80]: 3.826001249952242

In [81]: timeit.timeit(setup=mysetup, stmt=s2, number=100000000)
Out[81]: 3.703278487082571

I've tested this several times and diff2 is consistently faster than diff1.

Why? Shouldn't set(b) inside the list comprehension be computed only once?


Solution

  • diff2 will always be faster than diff1 (unless a == []), because in diff1 the expression set(b) has to be re-evaluated on every iteration of the list comprehension.

    Technically, the expressions inside the comprehension could be modifying b on every iteration, so the interpreter cannot hoist set(b) out of the loop; the membership check has to rebuild the set for every element of a, which adds overhead that is linear in len(b) each time.

    For instance you could write:

    a = set([5])
    # the element expression rebinds a, so the set(a) in the condition
    # may refer to a different object on every iteration
    [a := set(set(a)) for _ in range(50) if set(a)]
    

    Therefore the list comprehension needs to re-evaluate its condition every time.
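
    A quick way to see this directly is to count how often the set constructor actually runs (a minimal sketch; counting_set is a hypothetical wrapper, not part of the original code):

    calls = 0

    def counting_set(iterable):
        # wrapper around set() that records how often it is invoked
        global calls
        calls += 1
        return set(iterable)

    a = list(range(20))
    b = [3, 8, 9, 10]

    [x for x in a if x not in counting_set(b)]   # diff1-style: the condition rebuilds the set
    print(calls)                                 # 20 -- once per element of a

    calls = 0
    bset = counting_set(b)                       # diff2-style: the set is built up front
    [x for x in a if x not in bset]
    print(calls)                                 # 1 -- built exactly once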

    Profiling this on a larger sample confirms the observation:

    import random
    
    def generate_lists(N):
        """Generate some random lists with some overlap and some unique elements."""
        common_elements = random.sample(range(N), N // 2)      # values that may end up in both lists
        a = random.sample(range(N * 2), N)                      # N distinct values from a wider range
        b = random.sample(common_elements + random.sample(range(N * 2), N // 2), N)  # mix of shared and fresh values
        return a, b
    
    def reevaluate(a, b): 
        return [x for x in a if x not in set(b)]
    
    def precompute(a, b):
        bset = set(b)
        return [x for x in a if x not in bset]
    
    def no_set_usage(a, b):
        return [x for x in a if x not in b]
    
    data_size = [1000,2000,3000,4000,5000]
    
    approaches = [
        reevaluate,
        precompute,
        no_set_usage
    ]
    run_performance_comparison(approaches, data_size, setup=generate_lists)
    

    (plots: execution time vs. N for reevaluate, precompute and no_set_usage)

    We can clearly see that no_set_usage is in the superlinear regime, probably O(N**2); a sqrt-scaled plot confirms this. To see the O(N) scaling of precompute, we need to turn to a log/log plot:

    (plots: the same measurements on sqrt-scaled and log/log axes, showing the linear scaling of precompute)
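
    If you prefer a numeric check over reading the plot, the scaling exponent can also be estimated from the measured times (a rough sketch, assuming the per-size timings are collected into a list; numpy is an extra dependency not used in the original code):

    import numpy as np

    def scaling_exponent(sizes, times):
        # fit log(time) = k * log(N) + c; k near 1 means O(N), k near 2 means O(N**2)
        slope, _intercept = np.polyfit(np.log(sizes), np.log(times), 1)
        return slope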

    Profiling code:

    import timeit
    from contextlib import contextmanager
    from typing import Callable, Dict, List

    import matplotlib.pyplot as plt


    @contextmanager
    def data_provider(data_size, setup=lambda N: N, teardown=lambda: None):
        """Build the test data for one problem size and clean up afterwards."""
        data = setup(data_size)
        yield data
        teardown()

    
    def run_performance_comparison(approaches: List[Callable], data_size: List[int],
                                   setup=lambda N: N, teardown=lambda: None, number_of_repetitions=5, title='N'):
        """Time every approach for every problem size and plot the results."""
        approach_times: Dict[Callable, List[float]] = {approach: [] for approach in approaches}
    
        for N in data_size:
            with data_provider(N, setup, teardown) as data:
                for approach in approaches:
                    approach_time = timeit.timeit(lambda: approach(*data), number=number_of_repetitions)
                    approach_times[approach].append(approach_time)
    
        for approach in approaches:
            plt.plot(data_size, approach_times[approach], label=approach.__name__)
    
        plt.xlabel(title)
        plt.ylabel('Execution Time (seconds)')
        plt.title('Performance Comparison')
        plt.legend()
        plt.show()