Tags: python, numpy, intel-mkl

How are NumPy arrays overwritten from the interpreter's point of view?


I wrote two simple functions to learn CPython's behaviour regarding numpy arrays.

Environment: Python 3.12.1 and NumPy 1.26.2, built against MKL (the conda default).

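import numpy as np

def foo():
    # A fresh 1000x1000 array is allocated on every iteration; the previous one is freed.
    for i in range(100):
        H = np.random.rand(1000, 1000)

%timeit -r 100 foo()

def baaz():
    # H is preallocated once; each new random array is then copied into it.
    H = np.zeros((1000, 1000))
    for i in range(100):
        H[:, :] = np.random.rand(1000, 1000)

%timeit -r 100 baaz()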
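The %timeit magic only works inside IPython/Jupyter. A minimal plain-Python sketch of the same comparison with the standard timeit module might look like this; the n/reps parameters and the best_time helper are hypothetical additions for illustration, and no timings are claimed here.

```python
import timeit

import numpy as np


def foo(n=1000, reps=100):
    # A new n x n array is allocated each iteration; the previous one is freed.
    for _ in range(reps):
        H = np.random.rand(n, n)


def baaz(n=1000, reps=100):
    # H is preallocated, but np.random.rand still allocates a temporary
    # array, which is then copied into H.
    H = np.zeros((n, n))
    for _ in range(reps):
        H[:, :] = np.random.rand(n, n)


def best_time(fn, repeat=10, **kwargs):
    """Return the fastest of `repeat` single calls to fn(**kwargs), in seconds."""
    return min(timeit.repeat(lambda: fn(**kwargs), number=1, repeat=repeat))
```

Usage: compare best_time(foo) with best_time(baaz). Taking the minimum over repeats is a common way to reduce noise from other processes.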

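Using the dis module to inspect the bytecode, by calling dis.dis(foo()) and dis.dis(baaz()), I got the two outputs below. Note that dis.dis(foo()) does not disassemble foo. It calls foo() first and passes the return value, None, to dis.dis, which then disassembles the last traceback instead. The first listing is therefore not foo's bytecode. It is a generator from inside the dis module (it iterates over code.co_lines() and yields (start, line) pairs, as dis.findlinestarts does). To inspect a function, pass the function object itself: dis.dis(foo).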

Initially, I believed that baaz() should run faster than foo(), since it reuses the H array instead of deallocating and reallocating it on every iteration. However, foo() is consistently faster, and I am wondering what causes this. I cannot read bytecode fluently, but just by looking at the dis.dis(foo()) and dis.dis(baaz()) output, I can see that the listing for foo() is 13 lines longer than the one for baaz().

Output of dis.dis(foo()):


    671           0 RETURN_GENERATOR
                  2 POP_TOP
                  4 RESUME                   0
    
    676           6 LOAD_CONST               1 (None)
                  8 STORE_FAST               1 (lastline)
    
    677          10 LOAD_FAST                0 (code)
        -->      12 LOAD_ATTR                1 (NULL|self + co_lines)
                 32 CALL                     0
                 40 GET_ITER
            >>   42 FOR_ITER                23 (to 92)
                 46 UNPACK_SEQUENCE          3
                 50 STORE_FAST               2 (start)
                 52 STORE_FAST               3 (end)
                 54 STORE_FAST               4 (line)
    
    678          56 LOAD_FAST                4 (line)
                 58 POP_JUMP_IF_NOT_NONE     1 (to 62)
                 60 JUMP_BACKWARD           10 (to 42)
            >>   62 LOAD_FAST                4 (line)
                 64 LOAD_FAST                1 (lastline)
                 66 COMPARE_OP              55 (!=)
                 70 POP_JUMP_IF_TRUE         1 (to 74)
                 72 JUMP_BACKWARD           16 (to 42)
    
    679     >>   74 LOAD_FAST                4 (line)
                 76 STORE_FAST               1 (lastline)
    
    680          78 LOAD_FAST                2 (start)
                 80 LOAD_FAST                4 (line)
                 82 BUILD_TUPLE              2
                 84 YIELD_VALUE              1
                 86 RESUME                   1
                 88 POP_TOP
                 90 JUMP_BACKWARD           25 (to 42)
    
    677     >>   92 END_FOR
    
    681          94 RETURN_CONST             1 (None)
            >>   96 CALL_INTRINSIC_1         3 (INTRINSIC_STOPITERATION_ERROR)
                 98 RERAISE                  1
    ExceptionTable:
      4 to 58 -> 96 [0] lasti
      62 to 70 -> 96 [0] lasti
      74 to 94 -> 96 [0] lasti

Disassembled baaz():


  1           0 RESUME                   0

  2           2 LOAD_GLOBAL              0 (np)
             12 LOAD_ATTR                3 (NULL|self + zeros)
             32 LOAD_CONST               1 ((1000, 1000))
             34 CALL                     1
             42 STORE_FAST               0 (H)

  3          44 LOAD_GLOBAL              5 (NULL + range)
             54 LOAD_CONST               2 (100)
             56 CALL                     1
             64 GET_ITER
        >>   66 FOR_ITER                43 (to 156)
             70 STORE_FAST               1 (i)

  4          72 LOAD_GLOBAL              0 (np)
             82 LOAD_ATTR                6 (random)
            102 LOAD_ATTR                9 (NULL|self + rand)
            122 LOAD_CONST               3 (1000)
            124 LOAD_CONST               3 (1000)
            126 CALL                     2
            134 LOAD_FAST                0 (H)
            136 LOAD_CONST               0 (None)
            138 LOAD_CONST               0 (None)
            140 BUILD_SLICE              2
            142 LOAD_CONST               0 (None)
            144 LOAD_CONST               0 (None)
            146 BUILD_SLICE              2
            148 BUILD_TUPLE              2
            150 STORE_SUBSCR
            154 JUMP_BACKWARD           45 (to 66)

  3     >>  156 END_FOR
            158 RETURN_CONST             0 (None)
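To compare the two functions' bytecode, pass the function objects themselves to dis.dis; dis.dis(foo()) would call foo first and hand its return value (None) to dis.dis. A minimal sketch using only the standard dis module follows. NumPy does not even need to be imported, because np is a global that is looked up only when a function runs, and disassembly never runs it. Exact opcodes differ between CPython versions.

```python
import dis


def foo():
    for i in range(100):
        H = np.random.rand(1000, 1000)  # np is resolved only when foo() is called


def baaz():
    H = np.zeros((1000, 1000))
    for i in range(100):
        H[:, :] = np.random.rand(1000, 1000)


# Pass the function objects, not the results of calling them.
dis.dis(foo)
dis.dis(baaz)

# Count bytecode instructions instead of comparing listing lengths by eye.
n_foo = len(list(dis.get_instructions(foo)))
n_baaz = len(list(dis.get_instructions(baaz)))
print(f"foo: {n_foo} instructions, baaz: {n_baaz} instructions")
```

Disassembled this way, foo has fewer instructions than baaz. Either way, the instruction count says little about speed here: almost all of the time is spent inside NumPy's C code (random number generation, allocation, copying), which the bytecode does not show.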

P.S.: It may not be obvious why one would expect baaz() to be faster, but in a language like Julia this pattern of reusing a preallocated array is indeed faster (see "Understanding Julia multi-thread / multi-process design").


Solution

  • In both cases you create a new array every time np.random.rand(1000, 1000) is called, and that array is deallocated again afterwards. In baaz you additionally copy the new array into the preallocated H (8 MB per iteration for a 1000×1000 float64 array), so baaz does strictly more work. Hence it is slower.

    NumPy ufuncs provide a way to avoid such temporaries through their out= argument. Consider a simple case:

    arr[:] = arr + 1
    

    This always creates a new temporary array holding the result of arr + 1, which is then copied into arr. You can avoid the temporary by using:

    np.add(arr, 1, out=arr)
    

    Just a quick example of the above:

    In [31]: %%timeit -r 100 arr = np.zeros(1_000_000)
        ...: arr[:] = arr + 1
        ...:
        ...:
    1.85 ms ± 375 µs per loop (mean ± std. dev. of 100 runs, 100 loops each)
    
    In [32]: %%timeit -r 100 arr = np.zeros(1_000_000)
        ...: np.add(arr, 1, out=arr)
        ...:
        ...:
    418 µs ± 29.1 µs per loop (mean ± std. dev. of 100 runs, 1,000 loops each)
    

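    Unfortunately, the legacy numpy.random functions such as np.random.rand have no equivalent out= argument. The newer Generator API does offer one for some methods (for example Generator.random(out=...)), which fills an existing array in place. Numba may also help here, although I am not sure how well optimized its np.random support is; it is worth taking a look.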
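As a small sketch of the ufunc point above (the variable names are illustrative), the three update styles can be compared directly. Only arr[:] = arr + 1 goes through a separate temporary array; np.add(..., out=arr) and the augmented assignment arr += 1 write into arr's existing buffer. arr.ctypes.data is the address of that buffer.

```python
import numpy as np

arr = np.zeros(1_000_000)
buf = arr.ctypes.data               # address of arr's data buffer

tmp = arr + 1                       # the right-hand side alone is a brand-new array...
assert not np.shares_memory(tmp, arr)
arr[:] = tmp                        # ...which arr[:] = ... then copies into arr

np.add(arr, 1, out=arr)             # no temporary: the sum is written into arr directly
arr += 1                            # augmented assignment on an ndarray is also in place

assert arr.ctypes.data == buf       # arr kept the same buffer throughout
print(arr[:3])                      # [3. 3. 3.]
```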
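For the legacy np.random.rand there is no out= argument, but numpy.random.Generator.random (NumPy 1.17 and later) accepts one and fills a preallocated float64 array in place. A minimal sketch of baaz rewritten that way follows. The name baaz_inplace and its parameters are hypothetical, and the Generator's random stream differs from the legacy np.random one, so the values will not match the original code.

```python
import numpy as np


def baaz_inplace(rng, n=1000, reps=100):
    # Preallocate a C-contiguous float64 array, matching rng.random's default dtype.
    H = np.empty((n, n))
    for _ in range(reps):
        # Fills H in place: no temporary array and no extra copy.
        rng.random(out=H)
    return H
```

Usage, e.g. in IPython: rng = np.random.default_rng() and then %timeit -r 100 baaz_inplace(rng), timed alongside foo() and baaz(). No timings are claimed here. Passing a seed, as in np.random.default_rng(0), makes the output reproducible.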