Tags: python, list-comprehension, python-itertools, flatten

Why is itertools.chain faster than a flattening list comprehension?


In a discussion in the comments of this question, it was mentioned that while concatenating a sequence of strings simply takes ''.join([str1, str2, ...]), concatenating a sequence of lists would take something like list(itertools.chain(lst1, lst2, ...)), although you can also use a list comprehension like [x for y in [lst1, lst2, ...] for x in y]. What surprised me is that the first method is consistently faster than the second:

import random
import itertools

random.seed(100)
lsts = [[1] * random.randint(100, 1000) for i in range(1000)]

%timeit [x for y in lsts for x in y]
# 39.3 ms ± 436 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit list(itertools.chain.from_iterable(lsts))
# 30.6 ms ± 866 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit list(x for y in lsts for x in y)  # Proposed in comments
# 62.5 ms ± 504 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# Loop-based methods proposed in the comments (each %%timeit cell magic must start its own IPython cell)
%%timeit
a = []
for lst in lsts: a += lst
# 26.4 ms ± 634 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%%timeit
a = []
for lst in lsts: a.extend(lst)
# 26.7 ms ± 728 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

It is not a difference of orders of magnitude, but it is not negligible either. I was wondering how that could be, since list comprehensions are frequently among the fastest ways to solve a given problem. At first I thought that maybe the itertools.chain object would have a len that the list constructor could use to preallocate the necessary memory, but that is not the case (you cannot call len on itertools.chain objects). Is some custom itertools.chain-to-list conversion taking place somehow, or is itertools.chain taking advantage of some other mechanism?
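A quick way to check the preallocation idea (a sketch, assuming CPython): besides len(), the list constructor can ask an iterable for a size estimate through the optional __length_hint__ method from PEP 424, and operator.length_hint performs the same lookup. A list reports its exact size, while a chain object offers neither __len__ nor __length_hint__, so length_hint falls back to its default:

```python
import itertools
import operator

lsts = [[1] * 3, [1] * 5]  # small stand-in for the benchmark data

# len() works on a list, so length_hint returns the exact size.
print(operator.length_hint(lsts))  # -> 2

chained = itertools.chain.from_iterable(lsts)

# chain defines neither __len__ nor __length_hint__, so length_hint
# returns its default value (0); the chain is not consumed by this call.
print(operator.length_hint(chained))  # -> 0
```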

Tested in Python 3.6.3 on Windows 10 x64, if that is relevant.

EDIT:

It seems the fastest method after all is to call .extend on an empty list once per sublist (or, equivalently, use a += lst), as proposed by @zwer, probably because it copies each sublist as one "chunk" instead of working on a per-element basis.
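For anyone not using IPython, here is a sketch of the same comparison with the standard-library timeit module. The function names are illustrative (not from the thread), the data mirrors the setup above, and absolute timings will vary by machine and Python version, so no expected numbers are given:

```python
import functools
import itertools
import random
import timeit

random.seed(100)
lsts = [[1] * random.randint(100, 1000) for _ in range(1000)]

def flatten_listcomp(lsts):
    return [x for y in lsts for x in y]

def flatten_chain(lsts):
    return list(itertools.chain.from_iterable(lsts))

def flatten_genexpr(lsts):
    return list(x for y in lsts for x in y)

def flatten_iadd(lsts):
    a = []
    for lst in lsts:
        a += lst  # in-place concatenation copies the whole sublist in C
    return a

def flatten_extend(lsts):
    a = []
    for lst in lsts:
        a.extend(lst)  # same chunked copy as +=, via a method call
    return a

candidates = (flatten_listcomp, flatten_chain, flatten_genexpr,
              flatten_iadd, flatten_extend)

# Sanity check: every method must produce the same flattened list.
expected = flatten_listcomp(lsts)
assert all(f(lsts) == expected for f in candidates)

for f in candidates:
    # Best of 3 repeats of 5 calls each, reported per call.
    best = min(timeit.repeat(functools.partial(f, lsts), number=5, repeat=3)) / 5
    print(f"{f.__name__:18} {best * 1e3:.1f} ms")
```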


Solution

  • Here is the C implementation of itertools.chain.from_iterable. It's not hard to read even if you don't know C, and you can tell that everything happens at the C level (before the result is used to build a list in your code).

    The bytecode for the list comprehension looks like this:

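    import dis
    
    def f(lsts):
        return [x for y in lsts for x in y]
    
    # In Python 3.6, co_consts[1] is the nested <listcomp> code object.
    # (3.12+ inlines list comprehensions, PEP 709; use dis.dis(f) there.)
    dis.dis(f.__code__.co_consts[1])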
      2           0 BUILD_LIST               0
                  2 LOAD_FAST                0 (.0)
            >>    4 FOR_ITER                18 (to 24)
                  6 STORE_FAST               1 (y)
                  8 LOAD_FAST                1 (y)
                 10 GET_ITER
            >>   12 FOR_ITER                 8 (to 22)
                 14 STORE_FAST               2 (x)
                 16 LOAD_FAST                2 (x)
                 18 LIST_APPEND              3
                 20 JUMP_ABSOLUTE           12
            >>   22 JUMP_ABSOLUTE            4
            >>   24 RETURN_VALUE
    

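    These are all the Python interpreter operations involved in running the list comprehension. Note that the inner loop (offsets 12 to 20) executes five bytecode instructions, FOR_ITER, STORE_FAST, LOAD_FAST, LIST_APPEND and JUMP_ABSOLUTE, for every single element. Having all of the looping happen at the C level (in chain) rather than having the interpreter dispatch each bytecode step (in the comprehension) is what gives you that performance boost.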

    Still, the boost is small enough that I wouldn't worry about it. This is Python: readability over speed.


    Edit:

    For a generator expression wrapped in list():

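    # Note: the parameter is named lists but the body reads lsts, so this
    # picks up the global lsts (hence LOAD_GLOBAL in dis.dis(g) below).
    def g(lists):
        return list(x for y in lsts for x in y)
    
    # the generator expression's code object
    dis.dis(g.__code__.co_consts[1])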
      2           0 LOAD_FAST                0 (.0)
            >>    2 FOR_ITER                20 (to 24)
                  4 STORE_FAST               1 (y)
                  6 LOAD_FAST                1 (y)
                  8 GET_ITER
            >>   10 FOR_ITER                10 (to 22)
                 12 STORE_FAST               2 (x)
                 14 LOAD_FAST                2 (x)
                 16 YIELD_VALUE
                 18 POP_TOP
                 20 JUMP_ABSOLUTE           10
            >>   22 JUMP_ABSOLUTE            2
            >>   24 LOAD_CONST               0 (None)
                 26 RETURN_VALUE
    

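    So the interpreter executes a similar number of bytecode steps per element for the generator expression that list() consumes. The difference is that each element is handed out via YIELD_VALUE, which suspends the generator frame, and list() then has to resume that frame through the iterator protocol for every following element. That per-element Python-level overhead, as opposed to the single C-level LIST_APPEND instruction in the comprehension, is what slows it down.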

    dis.dis(f)
      2           0 LOAD_CONST               1 (<code object <listcomp> at 0x000000000FB58B70, file "<ipython-input-33-1d46ced34d66>", line 2>)
                  2 LOAD_CONST               2 ('f.<locals>.<listcomp>')
                  4 MAKE_FUNCTION            0
                  6 LOAD_FAST                0 (lsts)
                  8 GET_ITER
                 10 CALL_FUNCTION            1
                 12 RETURN_VALUE
    
    dis.dis(g)
      2           0 LOAD_GLOBAL              0 (list)
                  2 LOAD_CONST               1 (<code object <genexpr> at 0x000000000FF6F420, file "<ipython-input-40-0334a7cdeb8f>", line 2>)
                  4 LOAD_CONST               2 ('g.<locals>.<genexpr>')
                  6 MAKE_FUNCTION            0
                  8 LOAD_GLOBAL              1 (lsts)
                 10 GET_ITER
                 12 CALL_FUNCTION            1
                 14 CALL_FUNCTION            1
                 16 RETURN_VALUE
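To isolate the interpreter overhead described above, one can time the C chain.from_iterable against a pure-Python generator running the same nested loop, modeled on the "roughly equivalent" code given in the itertools documentation. This is a sketch: py_from_iterable is a hypothetical name, the data is illustrative rather than the question's, and the relative timings will vary by Python version:

```python
import itertools
import timeit

def py_from_iterable(iterables):
    # Same nested loop as chain.from_iterable, but every step runs as
    # interpreted bytecode (FOR_ITER, STORE_FAST, YIELD_VALUE, ...).
    for it in iterables:
        for element in it:
            yield element

lsts = [[1] * 500 for _ in range(1000)]  # illustrative data

c_version = list(itertools.chain.from_iterable(lsts))
py_version = list(py_from_iterable(lsts))
assert c_version == py_version  # identical results, different execution paths

for label, fn in [
    ("C chain.from_iterable", lambda: list(itertools.chain.from_iterable(lsts))),
    ("pure-Python generator", lambda: list(py_from_iterable(lsts))),
]:
    # Best of 3 repeats of 10 calls each, reported per call.
    best = min(timeit.repeat(fn, number=10, repeat=3)) / 10
    print(f"{label:22} {best * 1e3:.2f} ms per call")
```

The generator version pays for a frame resume and several bytecode instructions per element, which is the same cost the list(genexpr) variant in the question pays.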