
Efficient Filtering of Lists in a Dictionary of Lists


I'm working with some reasonably large datasets (500,000 datapoints with 30 variables each) and would like to find the most efficient methods for filtering them.

For compatibility with existing code, the data is structured as a dictionary of lists; it can't be converted (e.g. to a pandas DataFrame) and has to be filtered in situ.
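If "in situ" is taken literally, meaning the existing list objects must be modified rather than replaced by new ones, slice assignment (v[:] = ...) does that: any other code holding a reference to one of the lists sees the filtered values. Rebinding with data[k] = new_list would instead leave such references pointing at the old, unfiltered list. A minimal sketch on a two-column subset of the example data; the variable alias is only there for illustration:

```python
data = {'Param0': ['x1', 'x2', 'x3', 'x4', 'x5', 'x6'],
        'Param1': ['A', 'A', 'A', 'B', 'B', 'C']}
alias = data['Param1']  # a second reference to the same list object

keep = {'x2', 'x4'}
# Row positions to retain, computed before any list is modified
keep_idx = [i for i, x in enumerate(data['Param0']) if x in keep]

for v in data.values():
    # Slice assignment replaces the contents of the existing list object
    v[:] = [v[i] for i in keep_idx]

print(alias)  # ['A', 'B']
```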

Working example:

data = {'Param0':['x1','x2','x3','x4','x5','x6'],
        'Param1':['A','A','A','B','B','C'],
        'Param2': [100,200,150,80,90,50],
        'Param3': [20,60,40,30,30,5]}

# Param0 keys to keep
keep = ['x2', 'x4']

filtered = {k: [x for i, x in enumerate(v) if data['Param0'][i] in keep] for k, v in data.items()}

The resulting filtered dictionary is correct, but this approach is very slow at scale.

Are there any quicker ways of doing this?


Solution

  • Compute the positions to keep once, from the key column, then reuse them for every list:

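    # Find the row positions to keep once, using only the key column
    keep_idx = [i for i, v in enumerate(data['Param0']) if v in keep]
    # Reuse those positions for every column
    filtered = {k: [v[i] for i in keep_idx] for k, v in data.items()}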
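One further tweak, not part of the original answer: keep is a list, so every "v in keep" check scans it from the start, costing O(len(keep)) per row. Converting it to a set once makes each lookup constant time on average, which matters when keep is long (the timing below uses 100 keys). This assumes the key values are hashable, as strings and ints are. A minimal sketch; the helper name filter_by_keys and its key_column parameter are hypothetical:

```python
def filter_by_keys(data, keep, key_column='Param0'):
    """Return a new dict of lists holding only rows whose key_column value is in keep."""
    keep_set = set(keep)  # average O(1) membership tests instead of a linear list scan
    # Find the row positions once, using the key column only
    keep_idx = [i for i, v in enumerate(data[key_column]) if v in keep_set]
    # Reuse the same positions for every column
    return {k: [v[i] for i in keep_idx] for k, v in data.items()}


data = {'Param0': ['x1', 'x2', 'x3', 'x4', 'x5', 'x6'],
        'Param1': ['A', 'A', 'A', 'B', 'B', 'C'],
        'Param2': [100, 200, 150, 80, 90, 50],
        'Param3': [20, 60, 40, 30, 30, 5]}

print(filter_by_keys(data, ['x2', 'x4']))
# {'Param0': ['x2', 'x4'], 'Param1': ['A', 'B'], 'Param2': [200, 80], 'Param3': [60, 30]}
```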
    

    Timing

    import numpy as np
    from timeit import timeit
    
    # Solution from the question
    def test_1(data, keep):
        return {
            k: [x for i, x in enumerate(v) if data['Param0'][i] in keep]
            for k, v in data.items()
        }
    
    # First solution from @I'mahdi
    def test_2(data, keep):
        keep_idx = [i for i, v in enumerate(data['Param0']) if v in keep]
        return {
            k: [val for i, val in enumerate(v) if i in keep_idx]
            for k, v in data.items()
        }
    
    # Second solution from @I'mahdi
    def test_3(data, keep):
        keep_idx = [i for i, v in enumerate(data['Param0']) if v in keep]
        return {k: list(np.asarray(v)[keep_idx]) for k, v in data.items()}
    
    # Solution in this answer
    def test_4(data, keep):
        keep_idx = [i for i, v in enumerate(data['Param0']) if v in keep]
        return {k: [v[i] for i in keep_idx] for k, v in data.items()}
    
    
    data = {f"Param{i}": list(range(10_000)) for i in range(20)}
    keep = list(range(0, 10_000, 100))
    
    print(test_1(data, keep) == test_2(data, keep))
    print(test_2(data, keep) == test_3(data, keep))
    print(test_3(data, keep) == test_4(data, keep))
    
    for i in range(1, 5):
        t = timeit(f"test_{i}(data, keep)", globals=globals(), number=10)
        print(f"Solution {i}: {t:.3f}")
    

    prints True three times, followed by timings (total seconds for 10 runs each) similar to:

    Solution 1: 4.571
    Solution 2: 4.220
    Solution 3: 0.298
    Solution 4: 0.219
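Another standard-library option, not timed above: build a boolean mask once from the key column and pass it to itertools.compress for each column. compress(data, selectors) yields the items of data whose matching selector is truthy, so in CPython the per-column selection runs in C rather than in a Python-level index loop. Whether it beats Solution 4 depends on the data (e.g. how many rows are kept), so treat it as a candidate to add to the timing harness, not a measured result. The helper name filter_with_compress is hypothetical:

```python
from itertools import compress


def filter_with_compress(data, keep, key_column='Param0'):
    """Hypothetical helper: filter every column with one precomputed boolean mask."""
    keep_set = set(keep)
    # One True/False per row, computed once from the key column
    mask = [v in keep_set for v in data[key_column]]
    # compress() yields v[i] wherever mask[i] is True
    return {k: list(compress(v, mask)) for k, v in data.items()}


data = {'Param0': ['x1', 'x2', 'x3', 'x4', 'x5', 'x6'],
        'Param1': ['A', 'A', 'A', 'B', 'B', 'C']}
print(filter_with_compress(data, ['x2', 'x4']))
# {'Param0': ['x2', 'x4'], 'Param1': ['A', 'B']}
```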