I want to save storage by using small dtypes. However, when I add a number to an array (or multiply an array by a number), Numba changes the dtype to int64:
Pure NumPy:
In:
import numpy as np

def f():
    a = np.ones(10, dtype=np.uint8)
    return a + 1

f()
Out:
array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=uint8)
Now with Numba:
In:
import numpy as np
from numba import njit

@njit
def f():
    a = np.ones(10, dtype=np.uint8)
    return a + 1

f()
Out:
array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=int64)
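Why the two results differ can be checked against NumPy's own promotion rules. This is a minimal sketch, assuming (not confirmed by the output above) that Numba types the literal 1 as a concrete int64 instead of giving it NumPy's special treatment of Python scalars:

```python
import numpy as np

# A plain Python int is treated specially (value-based casting in NumPy 1.x,
# a "weak" scalar in NumPy 2), so the uint8 dtype of the array is kept.
with_python_int = np.result_type(np.uint8, 1)
# Two concrete dtypes promote to one that can hold both ranges.
with_int64 = np.result_type(np.uint8, np.int64)
# Two uint8 operands need no promotion at all.
with_uint8 = np.result_type(np.uint8, np.uint8)

print(with_python_int, with_int64, with_uint8)  # uint8 int64 uint8
```

If the assumption holds, the Numba result matches the second case, which also suggests the fix of making both operands uint8.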
One workaround is to replace a + 1 with a + np.ones(a.shape, dtype=a.dtype), but I can't imagine anything uglier.
Thanks a lot for your help!
I guess the simplest thing is to just add two np.uint8 values:
import numpy as np
from numba import njit

@njit
def f():
    a = np.ones(10, dtype=np.uint8)
    return a + np.uint8(1)

print(f().dtype)
Output:
uint8
I find this more elegant than changing the dtype of the whole array or working with np.ones or np.full.
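The same trick should carry over to multiplication, which the question also mentions. This is a minimal sketch, assuming Numba follows NumPy's rule that two uint8 operands give a uint8 result; scale_and_shift is a hypothetical name, and the try/except exists only so the snippet still runs where Numba is not installed:

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback only so this sketch runs without Numba; plain NumPy then executes it.
    def njit(func):
        return func


@njit
def scale_and_shift(a):
    # Every operation has two uint8 operands, so the result stays uint8.
    return a * np.uint8(3) + np.uint8(1)


a = np.ones(10, dtype=np.uint8)
b = scale_and_shift(a)
print(b, b.dtype)

# Small dtypes wrap around instead of promoting: 250 + 10 = 260, stored as 260 - 256 = 4.
wrapped = np.full(3, 250, dtype=np.uint8) + np.uint8(10)
print(wrapped, wrapped.dtype)
```

The last two lines are a reminder of the trade-off: keeping uint8 saves memory, but any intermediate value above 255 silently wraps modulo 256.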