I am getting errors when trying to run numpy.dot with numba. It seems to be supported (see e.g. "numpy: Faster np.dot/ multiply(element-wise multiplication) when one array is the same"), but the following code gives me the error below (it runs fine if I remove the njit decorator).
Code:
import numpy as np
import numba

@numba.njit()
def tst_dot():
    a = np.array([[1, 0], [0, 1]])
    b = np.array([[4, 1], [2, 2]])
    return np.dot(a, b)

print(tst_dot())
Error:
No implementation of function Function(<function dot at 0x00000280CC542EF0>) found for signature:
>>> dot(array(int64, 2d, C), array(int64, 2d, C))
There are 4 candidate implementations:
- Of which 2 did not match due to:
Overload in function 'dot_2': File: numba\np\linalg.py: Line 525.
With argument(s): '(array(int64, 2d, C), array(int64, 2d, C))':
Rejected as the implementation raised a specific error:
TypingError: Failed in nopython mode pipeline (step: native lowering)
Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function dot at 0x00000280CC542EF0>) found for signature:
>>> dot(array(int64, 2d, C), array(int64, 2d, C), array(int64, 2d, C))
There are 4 candidate implementations:
- Of which 2 did not match due to:
Overload in function 'dot_2': File: numba\np\linalg.py: Line 525.
With argument(s): '(array(int64, 2d, C), array(int64, 2d, C), array(int64, 2d, C))':
Rejected as the implementation raised a specific error:
TypingError: too many positional arguments
raised from C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\core\typing\templates.py:784
- Of which 2 did not match due to:
Overload in function 'dot_3': File: numba\np\linalg.py: Line 784.
With argument(s): '(array(int64, 2d, C), array(int64, 2d, C), array(int64, 2d, C))':
Rejected as the implementation raised a specific error:
LoweringError: Failed in nopython mode pipeline (step: native lowering)
unsupported dtype for <BLAS function>()
File "venv\lib\site-packages\numba\np\linalg.py", line 817:
def codegen(context, builder, sig, args):
<source elided>
return lambda left, right, out: _impl(left, right, out)
^
During: lowering "$10call_function.4 = call $2load_deref.0(left, right, out, func=$2load_deref.0, args=[Var(left, linalg.py:817), Var(right, linalg.py:817), Var(out, linalg.py:817)], kws=(), vararg=None, varkwarg=None, target=None)" at C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\np\linalg.py (817)
raised from C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\core\errors.py:837
During: resolving callee type: Function(<function dot at 0x00000280CC542EF0>)
During: typing of call at C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\np\linalg.py (460)
File "venv\lib\site-packages\numba\np\linalg.py", line 460:
def dot_impl(a, b):
<source elided>
out = np.empty((m, n), a.dtype)
return np.dot(a, b, out)
^
During: lowering "$8call_function.3 = call $2load_deref.0(left, right, func=$2load_deref.0, args=[Var(left, linalg.py:582), Var(right, linalg.py:582)], kws=(), vararg=None, varkwarg=None, target=None)" at C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\np\linalg.py (582)
raised from C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\core\typeinfer.py:1086
- Of which 2 did not match due to:
Overload in function 'dot_3': File: numba\np\linalg.py: Line 784.
With argument(s): '(array(int64, 2d, C), array(int64, 2d, C))':
Rejected as the implementation raised a specific error:
TypingError: missing a required argument: 'out'
raised from C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\core\typing\templates.py:784
During: resolving callee type: Function(<function dot at 0x00000280CC542EF0>)
During: typing of call at C:\Users\a_che\PycharmProjects\minCovTarget\tst4.py (164)
File "tst4.py", line 164:
def tst_dot(a, b):
<source elided>
return np.dot(a, b)
^
I have tried adding out=None as a third argument (even though it is meant to be optional), but it didn't help. I was expecting the same result as when not using numba.
The docs say:
Basic linear algebra is supported on 1-D and 2-D contiguous arrays of floating-point and complex numbers:
- numpy.dot()
- ...
However, your two arrays contain integers, as the signature in the error message shows:
dot(array(int64, 2d, C), array(int64, 2d, C))
Hence, the trick is to change the dtype to a floating-point type:
import numpy as np
import numba

@numba.njit()
def tst_dot():
    # Floating-point dtype, so numba can dispatch np.dot to BLAS
    a = np.array([[1, 0], [0, 1]], dtype=np.float32)
    b = np.array([[4, 1], [2, 2]], dtype=np.float32)
    return np.dot(a, b)

print(tst_dot())
Output:
[[4. 1.]
 [2. 2.]]
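If the integer arrays come from elsewhere and you cannot change how they are created, here is a minimal sketch of two possible workarounds. The function names dot_as_float and dot_int are made up for illustration, and the try/except around the import is only there so the snippet also runs where numba is not installed. The first variant casts to float64 inside the jitted function before calling np.dot; the second avoids np.dot (and therefore BLAS) with an explicit loop, so the integer dtype is kept. As far as I know, numba's np.dot also needs SciPy to be installed.

```python
import numpy as np

try:
    import numba
    njit = numba.njit
except ImportError:
    # Assumption: fall back to a no-op decorator if numba is missing
    def njit(func):
        return func


@njit
def dot_as_float(a, b):
    # Cast to float64 so np.dot receives a dtype that BLAS supports
    return np.dot(a.astype(np.float64), b.astype(np.float64))


@njit
def dot_int(a, b):
    # Explicit triple loop: no BLAS call, the result keeps a's dtype
    m, k = a.shape
    n = b.shape[1]
    out = np.zeros((m, n), dtype=a.dtype)
    for i in range(m):
        for j in range(n):
            for p in range(k):
                out[i, j] += a[i, p] * b[p, j]
    return out


a = np.array([[1, 0], [0, 1]])
b = np.array([[4, 1], [2, 2]])

float_result = dot_as_float(a, b)  # float64 result
int_result = dot_int(a, b)         # integer result

print(float_result)
print(int_result)
```

The cast is usually the simpler option; the loop version is only worth it when you really need an integer result, for example to avoid rounding with very large values.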