
Unable to use numpy.dot with numba


I am getting errors when trying to run numpy.dot with numba. It seems to be supported (e.g. in the question "numpy: Faster np.dot/ multiply(element-wise multiplication) when one array is the same"), but this code, for example, gives me the following error (it runs fine if I remove the njit decorator):

Code:

import numpy as np
import numba

@numba.njit()
def tst_dot():
    a = np.array([[1, 0], [0, 1]])
    b = np.array([[4, 1], [2, 2]])

    return np.dot(a, b)

print(tst_dot())

Error:

No implementation of function Function(<function dot at 0x00000280CC542EF0>) found for signature:
 
 >>> dot(array(int64, 2d, C), array(int64, 2d, C))
 
There are 4 candidate implementations:
      - Of which 2 did not match due to:
      Overload in function 'dot_2': File: numba\np\linalg.py: Line 525.
        With argument(s): '(array(int64, 2d, C), array(int64, 2d, C))':
       Rejected as the implementation raised a specific error:
         TypingError: Failed in nopython mode pipeline (step: native lowering)
       Failed in nopython mode pipeline (step: nopython frontend)
       No implementation of function Function(<function dot at 0x00000280CC542EF0>) found for signature:
        
        >>> dot(array(int64, 2d, C), array(int64, 2d, C), array(int64, 2d, C))
        
       There are 4 candidate implementations:
             - Of which 2 did not match due to:
             Overload in function 'dot_2': File: numba\np\linalg.py: Line 525.
               With argument(s): '(array(int64, 2d, C), array(int64, 2d, C), array(int64, 2d, C))':
              Rejected as the implementation raised a specific error:
                TypingError: too many positional arguments
         raised from C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\core\typing\templates.py:784
             - Of which 2 did not match due to:
             Overload in function 'dot_3': File: numba\np\linalg.py: Line 784.
               With argument(s): '(array(int64, 2d, C), array(int64, 2d, C), array(int64, 2d, C))':
              Rejected as the implementation raised a specific error:
                LoweringError: Failed in nopython mode pipeline (step: native lowering)
              unsupported dtype for <BLAS function>()
              
              File "venv\lib\site-packages\numba\np\linalg.py", line 817:
                          def codegen(context, builder, sig, args):
                              <source elided>
              
                      return lambda left, right, out: _impl(left, right, out)
                      ^
              
              During: lowering "$10call_function.4 = call $2load_deref.0(left, right, out, func=$2load_deref.0, args=[Var(left, linalg.py:817), Var(right, linalg.py:817), Var(out, linalg.py:817)], kws=(), vararg=None, varkwarg=None, target=None)" at C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\np\linalg.py (817)
         raised from C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\core\errors.py:837
       
       During: resolving callee type: Function(<function dot at 0x00000280CC542EF0>)
       During: typing of call at C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\np\linalg.py (460)
       
       
       File "venv\lib\site-packages\numba\np\linalg.py", line 460:
           def dot_impl(a, b):
               <source elided>
               out = np.empty((m, n), a.dtype)
               return np.dot(a, b, out)
               ^
       
       During: lowering "$8call_function.3 = call $2load_deref.0(left, right, func=$2load_deref.0, args=[Var(left, linalg.py:582), Var(right, linalg.py:582)], kws=(), vararg=None, varkwarg=None, target=None)" at C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\np\linalg.py (582)
  raised from C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\core\typeinfer.py:1086
      - Of which 2 did not match due to:
      Overload in function 'dot_3': File: numba\np\linalg.py: Line 784.
        With argument(s): '(array(int64, 2d, C), array(int64, 2d, C))':
       Rejected as the implementation raised a specific error:
         TypingError: missing a required argument: 'out'
  raised from C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\core\typing\templates.py:784

During: resolving callee type: Function(<function dot at 0x00000280CC542EF0>)
During: typing of call at C:\Users\a_che\PycharmProjects\minCovTarget\tst4.py (164)


File "tst4.py", line 164:
def tst_dot(a, b):
    <source elided>

    return np.dot(a, b)
    ^

I have tried adding out=None as a third argument (even though it is meant to be optional), but it didn't help. I was expecting the same result as when running without numba.


Solution

  • The Numba docs say:

    Basic linear algebra is supported on 1-D and 2-D contiguous arrays of floating-point and complex numbers:

    • numpy.dot()
    • ...

    However, your two arrays contain integers, as the error message indeed shows:

    dot(array(int64, 2d, C), array(int64, 2d, C))
    

    Hence, the fix is to give the arrays a floating-point dtype:

    import numpy as np
    import numba
    
    @numba.njit()
    def tst_dot():
        a = np.array([[1, 0], [0, 1]], dtype=np.float32)
        b = np.array([[4, 1], [2, 2]], dtype=np.float32)
    
        return np.dot(a, b)
    
    print(tst_dot())
    
    Output:

    [[4. 1.]
     [2. 2.]]