python numpy numba

numba.core.errors.TypingError: while using np.random.randint()


How can I use np.random.randint with Numba? The code below throws a very large error (full traceback: https://hastebin.com/kodixazewo.sql):

from numba import jit
import numpy as np

@jit(nopython=True)
def foo():
    # Fails with numba.core.errors.TypingError: size= is not supported in nopython mode
    a = np.random.randint(16, size=(3,3))
    return a

foo()

Solution

  • Numba only supports the scalar form of np.random.randint (the low and high arguments, without size). You can use np.ndindex to loop over the desired output shape and call np.random.randint once per element.

    Make sure the output dtype is large enough to hold every integer the randint call can return.

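    from numba import njit
    import numpy as np

    @njit
    def foo(size=(3,3)):
        # uint16 can hold every value randint(16) returns, i.e. 0..15
        out = np.empty(size, dtype=np.uint16)

        # np.ndindex yields one index tuple per element, for any number of dimensions
        for idx in np.ndindex(size):
            out[idx] = np.random.randint(16)  # scalar call, supported by Numba

        return out
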

    This works for any shape, for example:

    foo(size=(2,2,2))
    

    Results in (the values are random, so yours will differ):

    array([[[ 8,  7],
            [15,  2]],
    
           [[ 4, 13],
            [ 5, 11]]], dtype=uint16)