How to use np.random.randint with numba? This throws a very large error: https://hastebin.com/kodixazewo.sql
from numba import jit
import numpy as np

@jit(nopython=True)
def foo():
    a = np.random.randint(16, size=(3,3))
    return a

foo()
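Numba's nopython mode supports np.random.randint only with its first two arguments (low and high), so the size keyword is what fails here. Instead, you can use np.ndindex to loop over your desired output shape and call np.random.randint for each element individually.
Make sure the output dtype is large enough to hold the full range of integers the randint call can return.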
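The loop and the dtype advice can be tried in plain NumPy first. This is a minimal sketch without Numba; `fill_random` is a hypothetical helper name, and the `np.iinfo` check is one assumed way to confirm that the largest value randint can return (`high - 1`) fits the chosen dtype.

```python
import numpy as np

def fill_random(size, high, dtype=np.uint16):
    # Hypothetical helper: plain NumPy (no Numba) version of the ndindex loop.
    # randint(high) returns values in [0, high), so the largest is high - 1.
    if high - 1 > np.iinfo(dtype).max:
        raise ValueError("dtype too small for the requested range")
    out = np.empty(size, dtype=dtype)
    # np.ndindex yields every index tuple of the shape in C order.
    for idx in np.ndindex(size):
        out[idx] = np.random.randint(high)
    return out

print(list(np.ndindex((2, 2))))  # [(0, 0), (0, 1), (1, 0), (1, 1)]
```

For example, `np.iinfo(np.uint16).max` is 65535, so `high=70000` with `uint16` would be rejected.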
from numba import njit
import numpy as np

@njit
def foo(size=(3,3)):
    out = np.empty(size, dtype=np.uint16)
    for idx in np.ndindex(size):
        out[idx] = np.random.randint(16)
    return out
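If you also need a lower bound, the same loop can pass both `low` and `high`, which stays within the two arguments Numba accepts. This is a sketch under assumptions: `randint_array` is a hypothetical name, `int64` is chosen so negative `low` values fit, and the fallback decorator only exists so the snippet also runs where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch runs without Numba (plain Python speed).
    def njit(func):
        return func

@njit
def randint_array(low, high, size):
    # Same ndindex loop as foo, but with both low and high passed to randint.
    # int64 is an assumed dtype that can hold negative values as well.
    out = np.empty(size, dtype=np.int64)
    for idx in np.ndindex(size):
        out[idx] = np.random.randint(low, high)
    return out
```

Usage: `randint_array(5, 10, (2, 3))` returns a 2x3 array with values in [5, 10).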
This works for any output shape:
foo(size=(2,2,2))
Results in (the values are random, so yours will differ):
array([[[ 8,  7],
        [15,  2]],

       [[ 4, 13],
        [ 5, 11]]], dtype=uint16)
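Because the output is random, you may want reproducible results. Numba keeps its own random generator, and calling np.random.seed() from regular Python code does not seed it, so one option is to seed inside the jitted function. A minimal sketch; `seeded_randint` is a hypothetical name and the fallback decorator is only there so it runs without Numba.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch runs without Numba (then it is plain NumPy seeding).
    def njit(func):
        return func

@njit
def seeded_randint(seed, size):
    # Seeding here affects Numba's generator when compiled with Numba.
    np.random.seed(seed)
    out = np.empty(size, dtype=np.uint16)
    for idx in np.ndindex(size):
        out[idx] = np.random.randint(16)
    return out

a = seeded_randint(42, (2, 2, 2))
b = seeded_randint(42, (2, 2, 2))
print((a == b).all())  # True
```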