
How to use numba?


I ran into a problem using the numba jit decorator (@nb.jit). Here is the warning from Jupyter Notebook:

NumbaWarning: Compilation is falling back to object mode WITH looplifting enabled because Function "get_nb_freq" failed type inference due to: No implementation of function Function(<function dot at 0x00000190AC399B80>) found for signature:

This is the complete information:

[Screenshot: complete error output]

Here is my code:

import numba
import numpy as np

Num_celltype = 4

@numba.jit
def get_nb_freq( nb_count = None, onehot_ct = None):
    # nb_freq = onehot_ct.T @ nb_count
    nb_freq = np.dot(onehot_ct.T, nb_count)
    res = nb_freq/nb_freq.sum(axis = 1).reshape(Num_celltype,-1)
    return res

# onehot_ct is an array of shape (921600, 4)
# nb_count is an array with the same shape as onehot_ct
# Num_celltype equals 4

Solution

  • Based on the shapes you mentioned, we can create the arrays as:

    onehot_ct = np.random.rand(921600, 4)
    nb_count = np.random.rand(921600, 4)
    

    With these arrays your code works correctly and returns something like:

    [[0.25013102754197963 0.25021461207825463 0.2496806287276126  0.24997373165215303]
     [0.2501574139037384  0.25018726649940737 0.24975108864220968 0.24990423095464467]
     [0.25020550587624757 0.2501303498983212  0.24978335463279314 0.24988078959263807]
     [0.2501855533482036  0.2500913419625523  0.24979681404573967 0.24992629064350436]]
    

    This shows that the code itself works, and the problem seems to be that numba cannot infer the types of the arrays you pass in. An explicit signature can help here, because it lets us declare the argument and return types of the function manually. Based on the error, I think the following signature should solve your issue:

    @nb.jit("float64[:, ::1](float64[:, ::1], float32[:, ::1])")
    def get_nb_freq( nb_count = None, onehot_ct = None):
        nb_freq = np.dot(onehot_ct.T, nb_count)
        res = nb_freq/nb_freq.sum(axis=1).reshape(4, -1)
        return res
    

    But it fails again if you test it with get_nb_freq(nb_count.astype(np.float64), onehot_ct.astype(np.float32)), so the remaining cause is the mismatched dtypes passed to np.dot: numba's np.dot requires both operands to have the same dtype. Converting onehot_ct to np.float64 inside the function solves the issue:

    @nb.jit("float64[:, ::1](float64[:, ::1], float32[:, ::1])")
    def get_nb_freq( nb_count, onehot_ct):
        nb_freq = np.dot(onehot_ct.astype(np.float64).T, nb_count)
        res = nb_freq/nb_freq.sum(axis=1).reshape(4, -1)
        return res
    

    With this correction it ran on my machine. I also recommend writing numba equivalents (like this for np.dot) instead of calling np.dot or …, since explicit loops compiled by numba can be faster for a case like this.