Tags: python, python-3.x, performance

Repeating elements in nested lists a different number of times each: why are the smarter methods slower?


I saw this question today, and the asker clearly didn't show any research effort. But someone posted an answer whose code was very straightforward and verbose, so I wanted to post a more concise and elegant solution, and I wanted the smarter method to also be faster.

To save you a click, here is the problem. You are given a list of lists and a second list with the same number of sublists. Each sublist of the second list contains only integers and has the same length as the corresponding sublist of the first. Repeat each innermost element of the first nested list as many times as the corresponding integer in the second.

Example:

data = ([2, 0, 2, 2],
 [3, 3, 1, 2],
 [1, 0, 3, 3],
 [1, 1, 1, 2],
 [0, 0, 2, 1],
 [0, 1, 3, 3],
 [3, 1, 3, 2],
 [1, 0, 1, 2])

mult = ([3, 0, 0, 3],
 [2, 2, 1, 1],
 [0, 2, 2, 1],
 [3, 3, 3, 2],
 [0, 2, 3, 2],
 [1, 1, 3, 2],
 [3, 1, 2, 3],
 [3, 2, 0, 0])

output = deque([[2, 2, 2, 2, 2, 2],
       [3, 3, 3, 3, 1, 2],
       [0, 0, 3, 3, 3],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2],
       [0, 0, 2, 2, 2, 1, 1],
       [0, 1, 3, 3, 3, 3, 3],
       [3, 3, 3, 1, 3, 3, 2, 2, 2],
       [1, 1, 1, 0, 0]])

I quickly came up with a list comprehension solution:

def repeat_element_listcomp(data, mult):
    return [[i for i, j in zip(a, b) for _ in range(j)] for a, b in zip(data, mult)]
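For reference, the nested comprehension above behaves roughly like the explicit loops below. This is a sketch, and the name `repeat_element_loops` is hypothetical. The innermost step adds one element per repetition, so the amount of Python bytecode executed grows with the total length of the output, not with the number of (element, count) pairs. The real comprehension uses the dedicated `LIST_APPEND` instruction rather than a `.append` method call, but the per-element step is still there.

```python
def repeat_element_loops(data, mult):
    # Hypothetical explicit-loop equivalent of repeat_element_listcomp
    result = []
    for a, b in zip(data, mult):
        sub = []
        for i, j in zip(a, b):
            # One range object per (element, count) pair...
            for _ in range(j):
                # ...and one append per repeated element
                sub.append(i)
        result.append(sub)
    return result


data = ([2, 0, 2, 2], [1, 0, 1, 2])
mult = ([3, 0, 0, 3], [3, 2, 0, 0])
print(repeat_element_loops(data, mult))
# [[2, 2, 2, 2, 2, 2], [1, 1, 1, 0, 0]]
```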

But I was surprised to find it slower than the simple solution of the first answer:

def repeat_element_zero(data, mult):
    combined = []
    for sublist1, sublist2 in zip(data, mult):
        sublist = []
        for elem1, elem2 in zip(sublist1, sublist2):
            sublist.extend([elem1]* elem2)

        combined.append(sublist)

    return combined

In [229]: %timeit repeat_element_zero(data, mult)
9.86 µs ± 129 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [230]: %timeit repeat_element(data, mult)
14.1 µs ± 156 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
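The key line in `repeat_element_zero` is `sublist.extend([elem1] * elem2)`. The multiplication builds the whole run of repeated references in a single C-level operation, and `extend` copies that run in another, so the bytecode executed per (element, count) pair stays the same however large the count is. A quick illustration, with arbitrary example values:

```python
elem1, elem2 = 3, 4

run = [elem1] * elem2  # one operation, no Python-level loop
print(run)  # [3, 3, 3, 3]

print([elem1] * 0)  # [] : a count of 0 contributes nothing

sublist = [1, 2]
sublist.extend(run)  # appends all four references in one call
print(sublist)  # [1, 2, 3, 3, 3, 3]
```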

I spent many minutes trying to come up with a more efficient solution and tried many other clever methods, but somehow all of them were slower. I then posted an answer there.

Setup:

import random
from collections import deque
from functools import reduce
from itertools import chain
from operator import iconcat

def random_numbers(n):
    return random.choices(range(n), k=n)

def make_data(n):
    return random_numbers(n), random_numbers(n)

def make_sample(n, limit=300):
    return list(zip(*[make_data(limit) for _ in range(n)]))

def repeat_element_zero(data, mult):
    combined = []
    for sublist1, sublist2 in zip(data, mult):
        sublist = []
        for elem1, elem2 in zip(sublist1, sublist2):
            sublist.extend([elem1]* elem2)

        combined.append(sublist)
    
    return combined

def repeat_element_listcomp(data, mult):
    return [[i for i, j in zip(a, b) for _ in range(j)] for a, b in zip(data, mult)]

def repeat_element_chain(data, mult):
    return [list(chain(*([i]*j for i, j in zip(a, b)))) for a, b in zip(data, mult)]

def repeat_element_helper(data, mult):
    return reduce(iconcat, ([i]*j for i, j in zip(data, mult)), [])

def repeat_element(data, mult):
    return deque(map(repeat_element_helper, data, mult))

approaches = [
    repeat_element_listcomp,
    repeat_element_chain,
    repeat_element,
]
run_performance_comparison(approaches, [1000, 2000, 3000], setup=make_sample)
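In `repeat_element_helper`, `reduce(iconcat, ..., [])` is essentially the extend loop from `repeat_element_zero` in disguise. `operator.iconcat(a, b)` performs `a += b` and returns `a`, so each step extends the single `[]` accumulator in place. The sketch below uses made-up sample values to show the equivalence and why the initial `[]` matters. It performs the same per-pair work as `repeat_element_zero` (`[i] * j` plus an in-place extend), only driven through a generator and `reduce`. That may help explain why its timings track `repeat_element_zero` more closely than the comprehension's do.

```python
from functools import reduce
from operator import iconcat

# operator.iconcat(a, b) performs `a += b` and returns a,
# so for lists it extends a in place.
acc = []
parts = ([2] * 3, [0] * 0, [2] * 0, [2] * 3)
result = reduce(iconcat, parts, acc)
print(result)         # [2, 2, 2, 2, 2, 2]
print(result is acc)  # True: every step extended the same list

# The same work written as the extend loop from repeat_element_zero
manual = []
for part in parts:
    manual.extend(part)
print(manual == result)  # True

# Without the initial [], reduce starts from the first item and
# extends that list in place, mutating the caller's data.
pieces = [[1], [2], [3]]
flat = reduce(iconcat, pieces)
print(flat)       # [1, 2, 3]
print(pieces[0])  # [1, 2, 3] (the first piece was modified)
```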

Performance:

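[Figure: Performance Comparison, execution time (seconds) vs. N for repeat_element_listcomp, repeat_element_chain and repeat_element]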

In [188]: data, mult = make_sample(32, 10)

In [189]: %timeit repeat_element_zero(data, mult)
102 µs ± 3.36 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [190]: %timeit repeat_element_listcomp(data, mult)
145 µs ± 3.55 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [191]: %timeit repeat_element_chain(data, mult)
141 µs ± 4.74 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [192]: %timeit repeat_element(data, mult)
127 µs ± 1.4 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [193]: data, mult = make_sample(32, 32)

In [194]: %timeit repeat_element(data, mult)
576 µs ± 10.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In [195]: %timeit repeat_element_chain(data, mult)
647 µs ± 16.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In [196]: %timeit repeat_element_listcomp(data, mult)
837 µs ± 12.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In [197]: %timeit repeat_element_zero(data, mult)
465 µs ± 15.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In [198]: data, mult = make_sample(256, 32)

In [199]: %timeit repeat_element_zero(data, mult)
3.69 ms ± 64.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [200]: %timeit repeat_element(data, mult)
4.47 ms ± 88.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [201]: %timeit repeat_element_listcomp(data, mult)
7.01 ms ± 688 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Profiling Code:

import timeit
import matplotlib.pyplot as plt
from typing import List, Dict, Callable

from contextlib import contextmanager


@contextmanager
def data_provider(data_size, setup=lambda N: N, teardown=lambda: None):
    data = setup(data_size)
    yield data
    teardown()


def run_performance_comparison(approaches: List[Callable],
                               data_size: List[int],
                               setup=lambda N: N,
                               teardown=lambda: None,
                               number_of_repetitions=5, title='N'):
    approach_times: Dict[Callable, List[float]] = {approach: [] for approach in approaches}

    for N in data_size:
        with data_provider(N, setup, teardown) as data:
            for approach in approaches:
                approach_time = timeit.timeit(lambda: approach(*data), number=number_of_repetitions)
                approach_times[approach].append(approach_time)

    for approach in approaches:
        plt.plot(data_size, approach_times[approach], label=approach.__name__)

    plt.xlabel(title)
    plt.ylabel('Execution Time (seconds)')
    plt.title('Performance Comparison')
    plt.legend()
    plt.show()
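`run_performance_comparison` records a single `timeit.timeit` measurement per size and approach, so one noisy run (for example, another process competing for the CPU) can bend a curve. One possible variant repeats each measurement and keeps the minimum, which is what the `timeit` documentation suggests looking at. In the sketch below, `best_time` is a hypothetical helper. It could replace the `timeit.timeit` call inside the loop.

```python
import timeit


def best_time(func, args, number=5, repeat=7):
    """Hypothetical helper: best per-call time of func(*args), in seconds.

    timeit.repeat performs `repeat` rounds of `number` calls each; the
    minimum round is the estimate least disturbed by other processes.
    """
    totals = timeit.repeat(lambda: func(*args), number=number, repeat=repeat)
    return min(totals) / number


def repeat_element_zero(data, mult):
    combined = []
    for sublist1, sublist2 in zip(data, mult):
        sublist = []
        for elem1, elem2 in zip(sublist1, sublist2):
            sublist.extend([elem1] * elem2)
        combined.append(sublist)
    return combined


data = ([2, 0, 2, 2], [3, 3, 1, 2])
mult = ([3, 0, 0, 3], [2, 2, 1, 1])
print(f"{best_time(repeat_element_zero, (data, mult)):.2e} s per call")
```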

I want to know why all of my smart methods are slower. What is going on here? Why do these methods, which normally improve performance, make the code slower in this case? I guess this must be a CPython implementation detail. In case the version matters, I am using Python 3.10.11 (tags/v3.10.11:7d4cc5a, Apr 5 2023, 00:38:17) [MSC v.1929 64 bit (AMD64)]. I don't know C yet, so I don't understand the low-level details, but I am really curious and would like an explanation.


Disassembly of the functions:

In [232]: import dis

In [233]: dis.dis(repeat_element_zero)
 11           0 BUILD_LIST               0
              2 STORE_FAST               2 (combined)

 12           4 LOAD_GLOBAL              0 (zip)
              6 LOAD_FAST                0 (data)
              8 LOAD_FAST                1 (mult)
             10 CALL_FUNCTION            2
             12 GET_ITER
        >>   14 FOR_ITER                29 (to 74)
             16 UNPACK_SEQUENCE          2
             18 STORE_FAST               3 (sublist1)
             20 STORE_FAST               4 (sublist2)

 13          22 BUILD_LIST               0
             24 STORE_FAST               5 (sublist)

 14          26 LOAD_GLOBAL              0 (zip)
             28 LOAD_FAST                3 (sublist1)
             30 LOAD_FAST                4 (sublist2)
             32 CALL_FUNCTION            2
             34 GET_ITER
        >>   36 FOR_ITER                12 (to 62)
             38 UNPACK_SEQUENCE          2
             40 STORE_FAST               6 (elem1)
             42 STORE_FAST               7 (elem2)

 15          44 LOAD_FAST                5 (sublist)
             46 LOAD_METHOD              1 (extend)
             48 LOAD_FAST                6 (elem1)
             50 BUILD_LIST               1
             52 LOAD_FAST                7 (elem2)
             54 BINARY_MULTIPLY
             56 CALL_METHOD              1
             58 POP_TOP
             60 JUMP_ABSOLUTE           18 (to 36)

 17     >>   62 LOAD_FAST                2 (combined)
             64 LOAD_METHOD              2 (append)
             66 LOAD_FAST                5 (sublist)
             68 CALL_METHOD              1
             70 POP_TOP
             72 JUMP_ABSOLUTE            7 (to 14)

 19     >>   74 LOAD_FAST                2 (combined)
             76 RETURN_VALUE

In [234]: dis.dis(repeat_element)
 31           0 LOAD_GLOBAL              0 (deque)
              2 LOAD_GLOBAL              1 (map)
              4 LOAD_GLOBAL              2 (repeat_element_helper)
              6 LOAD_FAST                0 (data)
              8 LOAD_FAST                1 (mult)
             10 CALL_FUNCTION            3
             12 CALL_FUNCTION            1
             14 RETURN_VALUE

In [235]: dis.dis(repeat_element_helper)
 28           0 LOAD_GLOBAL              0 (reduce)
              2 LOAD_GLOBAL              1 (iconcat)
              4 LOAD_CONST               1 (<code object <genexpr> at 0x000001C86CC20240, file "<ipython-input-225-c385a750c738>", line 28>)
              6 LOAD_CONST               2 ('repeat_element_helper.<locals>.<genexpr>')
              8 MAKE_FUNCTION            0
             10 LOAD_GLOBAL              2 (zip)
             12 LOAD_FAST                0 (data)
             14 LOAD_FAST                1 (mult)
             16 CALL_FUNCTION            2
             18 GET_ITER
             20 CALL_FUNCTION            1
             22 BUILD_LIST               0
             24 CALL_FUNCTION            3
             26 RETURN_VALUE

Disassembly of <code object <genexpr> at 0x000001C86CC20240, file "<ipython-input-225-c385a750c738>", line 28>:
              0 GEN_START                0

 28           2 LOAD_FAST                0 (.0)
        >>    4 FOR_ITER                10 (to 26)
              6 UNPACK_SEQUENCE          2
              8 STORE_FAST               1 (i)
             10 STORE_FAST               2 (j)
             12 LOAD_FAST                1 (i)
             14 BUILD_LIST               1
             16 LOAD_FAST                2 (j)
             18 BINARY_MULTIPLY
             20 YIELD_VALUE
             22 POP_TOP
             24 JUMP_ABSOLUTE            2 (to 4)
        >>   26 LOAD_CONST               0 (None)
             28 RETURN_VALUE

In [236]: dis.dis(repeat_element_listcomp)
 22           0 LOAD_CONST               1 (<code object <listcomp> at 0x000001C86CE6B7E0, file "<ipython-input-225-c385a750c738>", line 22>)
              2 LOAD_CONST               2 ('repeat_element_listcomp.<locals>.<listcomp>')
              4 MAKE_FUNCTION            0
              6 LOAD_GLOBAL              0 (zip)
              8 LOAD_FAST                0 (data)
             10 LOAD_FAST                1 (mult)
             12 CALL_FUNCTION            2
             14 GET_ITER
             16 CALL_FUNCTION            1
             18 RETURN_VALUE

Disassembly of <code object <listcomp> at 0x000001C86CE6B7E0, file "<ipython-input-225-c385a750c738>", line 22>:
 22           0 BUILD_LIST               0
              2 LOAD_FAST                0 (.0)
        >>    4 FOR_ITER                14 (to 34)
              6 UNPACK_SEQUENCE          2
              8 STORE_FAST               1 (a)
             10 STORE_FAST               2 (b)
             12 LOAD_CONST               0 (<code object <listcomp> at 0x000001C86489E550, file "<ipython-input-225-c385a750c738>", line 22>)
             14 LOAD_CONST               1 ('repeat_element_listcomp.<locals>.<listcomp>.<listcomp>')
             16 MAKE_FUNCTION            0
             18 LOAD_GLOBAL              0 (zip)
             20 LOAD_FAST                1 (a)
             22 LOAD_FAST                2 (b)
             24 CALL_FUNCTION            2
             26 GET_ITER
             28 CALL_FUNCTION            1
             30 LIST_APPEND              2
             32 JUMP_ABSOLUTE            2 (to 4)
        >>   34 RETURN_VALUE

Disassembly of <code object <listcomp> at 0x000001C86489E550, file "<ipython-input-225-c385a750c738>", line 22>:
 22           0 BUILD_LIST               0
              2 LOAD_FAST                0 (.0)
        >>    4 FOR_ITER                13 (to 32)
              6 UNPACK_SEQUENCE          2
              8 STORE_FAST               1 (i)
             10 STORE_FAST               2 (j)
             12 LOAD_GLOBAL              0 (range)
             14 LOAD_FAST                2 (j)
             16 CALL_FUNCTION            1
             18 GET_ITER
        >>   20 FOR_ITER                 4 (to 30)
             22 STORE_FAST               3 (_)
             24 LOAD_FAST                1 (i)
             26 LIST_APPEND              3
             28 JUMP_ABSOLUTE           10 (to 20)
        >>   30 JUMP_ABSOLUTE            2 (to 4)
        >>   32 RETURN_VALUE

In [237]: dis.dis(repeat_element_chain)
 25           0 LOAD_CONST               1 (<code object <listcomp> at 0x000001C86CC22600, file "<ipython-input-225-c385a750c738>", line 25>)
              2 LOAD_CONST               2 ('repeat_element_chain.<locals>.<listcomp>')
              4 MAKE_FUNCTION            0
              6 LOAD_GLOBAL              0 (zip)
              8 LOAD_FAST                0 (data)
             10 LOAD_FAST                1 (mult)
             12 CALL_FUNCTION            2
             14 GET_ITER
             16 CALL_FUNCTION            1
             18 RETURN_VALUE

Disassembly of <code object <listcomp> at 0x000001C86CC22600, file "<ipython-input-225-c385a750c738>", line 25>:
 25           0 BUILD_LIST               0
              2 LOAD_FAST                0 (.0)
        >>    4 FOR_ITER                18 (to 42)
              6 UNPACK_SEQUENCE          2
              8 STORE_FAST               1 (a)
             10 STORE_FAST               2 (b)
             12 LOAD_GLOBAL              0 (list)
             14 LOAD_GLOBAL              1 (chain)
             16 LOAD_CONST               0 (<code object <genexpr> at 0x000001C86BF07260, file "<ipython-input-225-c385a750c738>", line 25>)
             18 LOAD_CONST               1 ('repeat_element_chain.<locals>.<listcomp>.<genexpr>')
             20 MAKE_FUNCTION            0
             22 LOAD_GLOBAL              2 (zip)
             24 LOAD_FAST                1 (a)
             26 LOAD_FAST                2 (b)
             28 CALL_FUNCTION            2
             30 GET_ITER
             32 CALL_FUNCTION            1
             34 CALL_FUNCTION_EX         0
             36 CALL_FUNCTION            1
             38 LIST_APPEND              2
             40 JUMP_ABSOLUTE            2 (to 4)
        >>   42 RETURN_VALUE

Disassembly of <code object <genexpr> at 0x000001C86BF07260, file "<ipython-input-225-c385a750c738>", line 25>:
              0 GEN_START                0

 25           2 LOAD_FAST                0 (.0)
        >>    4 FOR_ITER                10 (to 26)
              6 UNPACK_SEQUENCE          2
              8 STORE_FAST               1 (i)
             10 STORE_FAST               2 (j)
             12 LOAD_FAST                1 (i)
             14 BUILD_LIST               1
             16 LOAD_FAST                2 (j)
             18 BINARY_MULTIPLY
             20 YIELD_VALUE
             22 POP_TOP
             24 JUMP_ABSOLUTE            2 (to 4)
        >>   26 LOAD_CONST               0 (None)
             28 RETURN_VALUE

I barely understand any of the commands.
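One way to read the disassembly: in `repeat_element_zero`, the block from offset 36 to 60 runs once per `(elem1, elem2)` pair, and the repetition itself happens inside `BINARY_MULTIPLY` and the `extend` call, in C. In the inner `<listcomp>` of `repeat_element_listcomp`, the loop from offset 20 to 28 (`FOR_ITER`, `STORE_FAST`, `LOAD_FAST`, `LIST_APPEND`, `JUMP_ABSOLUTE`) runs once per output element, on top of a `range` call per pair. The sketch below counts the bytecode instructions actually executed, using `sys.settrace` with per-opcode tracing (`frame.f_trace_opcodes`, available since Python 3.7). `count_opcodes` is a hypothetical helper, and absolute counts differ between CPython versions. The trend should hold, though: raising the multipliers leaves the count for `repeat_element_zero` unchanged, while the comprehension's count grows with the output length.

```python
import sys


def count_opcodes(func, *args):
    """Hypothetical helper: number of bytecode instructions executed by func(*args).

    Work done inside C functions (list.extend, list multiplication, zip,
    range) creates no Python frame, so it is not counted.
    """
    count = 0

    def tracer(frame, event, arg):
        nonlocal count
        if event == "call":
            # Request an 'opcode' event for every instruction in this frame
            frame.f_trace_opcodes = True
        elif event == "opcode":
            count += 1
        return tracer

    sys.settrace(tracer)
    try:
        func(*args)
    finally:
        sys.settrace(None)
    return count


def repeat_element_zero(data, mult):
    combined = []
    for sublist1, sublist2 in zip(data, mult):
        sublist = []
        for elem1, elem2 in zip(sublist1, sublist2):
            sublist.extend([elem1] * elem2)
        combined.append(sublist)
    return combined


def repeat_element_listcomp(data, mult):
    return [[i for i, j in zip(a, b) for _ in range(j)] for a, b in zip(data, mult)]


data = [[1, 2, 3, 4]]
for k in (10, 100, 1000):
    mult = [[k] * 4]  # every element repeated k times
    print(k,
          count_opcodes(repeat_element_zero, data, mult),
          count_opcodes(repeat_element_listcomp, data, mult))
```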


Solution

  • If you benchmark the individual components (comprehension vs. multiplication, and comprehension vs. multiple .extend calls), you will see that comprehensions are not always faster than low-level code that processes multiple items at a time.

    from timeit import timeit
    
    # Comprehension vs. list multiplication
    t0 = timeit(lambda: [3 for _ in range(1000)], number=1000)
    t1 = timeit(lambda: [3] * 1000, number=1000)
    
    print(t0) # 0.0317
    print(t1) # 0.0018
    
    # Nested comprehension vs. repeated .extend() calls
    # (any() only drives the generator: extend() returns None, so no item stops it early)
    source = [1, 2, 3] * 100
    t0 = timeit(lambda: [n for _ in range(10) for n in source], number=1000)
    t1 = timeit(lambda: [any(L.extend(source) for L in [[]] for _ in range(10))], number=1000)
    
    print(t0) # 0.0514
    print(t1) # 0.0072
    

    Results may vary depending on the sizes and multipliers, but most of the time this should hold true.
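Taking the answer's point one step further, the inner loop can be handed to C entirely. `map(itertools.repeat, a, b)` yields one `repeat(i, j)` iterator per pair, and `chain.from_iterable` flattens them lazily. By contrast, `chain(*...)` in `repeat_element_chain` first unpacks the whole generator into an argument tuple (the `CALL_FUNCTION_EX` in its disassembly). In this variant, no bytecode runs per pair or per element, only the outer comprehension runs once per sublist. The name `repeat_element_repeat` is hypothetical, and the sketch has not been timed here. Whether it actually beats `repeat_element_zero` depends on the data and the Python version, so it would need to go through the same benchmark.

```python
from itertools import chain, repeat


def repeat_element_repeat(data, mult):
    # Hypothetical variant: per sublist, map(repeat, a, b) creates repeat(i, j)
    # iterators, chain.from_iterable flattens them, and list() consumes the
    # result, all in C; only the outer comprehension runs per sublist.
    return [list(chain.from_iterable(map(repeat, a, b))) for a, b in zip(data, mult)]


data = ([2, 0, 2, 2], [3, 3, 1, 2], [1, 0, 3, 3])
mult = ([3, 0, 0, 3], [2, 2, 1, 1], [0, 2, 2, 1])
print(repeat_element_repeat(data, mult))
# [[2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 1, 2], [0, 0, 3, 3, 3]]
```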