Tags: python, numpy, numba, array-broadcasting

Numba messes up dtype when broadcasting


I want to save memory by using small dtypes. However, when I add a number to an array (or multiply an array by a number), Numba changes the dtype to int64:

Pure NumPy

In:

import numpy as np

def f():
    a = np.ones(10, dtype=np.uint8)
    return a + 1
f()

Out:

array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=uint8)

Now with Numba:

In:

from numba import njit

@njit
def f():
    a = np.ones(10, dtype=np.uint8)
    return a + 1
f()

Out:

array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=int64)
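The difference comes from how the scalar 1 is typed. Plain NumPy treats a Python int "weakly" (value-based casting in NumPy 1.x, NEP 50 in NumPy 2.x), so the uint8 array wins. The sketch below assumes that Numba types the literal 1 as int64 (as on a typical 64-bit platform) and then promotes purely by type; it uses only NumPy's own promotion functions to mimic the two behaviours, not Numba's internals.

```python
import numpy as np

# Plain NumPy: uint8 array combined with a Python int keeps uint8,
# because the int's value (1) fits in uint8.
numpy_style = np.result_type(np.uint8, 1)

# Assumed njit behaviour: the literal is already an int64, so promotion
# is type-based and uint8 combined with int64 gives int64.
numba_style = np.promote_types(np.uint8, np.int64)

# Making the scalar a uint8 first leaves nothing to promote.
cast_scalar = np.promote_types(np.uint8, np.uint8)

print(numpy_style, numba_style, cast_scalar)  # uint8 int64 uint8
```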

One workaround is to replace a+1 with a+np.ones(a.shape, dtype=a.dtype), but I can hardly imagine anything uglier.

Thanks a lot for your help!


Solution

  • I guess the simplest thing is to add two np.uint8 values, i.e. make the scalar a np.uint8 as well:

    import numpy as np
    from numba import njit
    
    @njit
    def f():
        a = np.ones(10, dtype=np.uint8)
        return a + np.uint8(1)
    print(f().dtype)
    

    Output:

    uint8
    

    I find this more elegant than changing the type of the full array or working with np.ones or np.full.
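If you want to avoid hardcoding np.uint8, a.dtype.type gives the NumPy scalar class that matches the array's dtype. The helper below, add_scalar_same_dtype, is a hypothetical name, and it is a plain NumPy sketch: I have not verified that a.dtype.type is supported inside an @njit function. It also shows the trade-off of staying in uint8: results wrap around modulo 256 instead of being widened.

```python
import numpy as np

def add_scalar_same_dtype(a, value):
    """Hypothetical helper: add `value` to `a` without changing a's dtype.

    a.dtype.type is the scalar class for the array's dtype (np.uint8 here),
    so this generalizes `a + np.uint8(1)` to other dtypes. Plain NumPy only.
    """
    return a + a.dtype.type(value)

a = np.ones(10, dtype=np.uint8)
b = add_scalar_same_dtype(a, 1)
print(b.dtype, b[:3])  # uint8 [2 2 2]

# Caveat of keeping a small dtype: array arithmetic wraps around silently.
full = np.full(3, 255, dtype=np.uint8)
wrapped = full + np.uint8(1)
print(wrapped)  # [0 0 0]
```

Whether the wraparound is acceptable depends on the value range of your data; if sums can exceed 255, a wider dtype is still needed.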