I am writing a function that I want to optimize with Numba. The function performs operations on two vectors along with some scalar operations. When I set @jit's parallel argument to False, everything runs fine, but when it is set to True the returned result is just a vector of zeros. I've been playing with a toy version of the function to try to understand the problem but have made no progress, and nothing I've seen in the Numba documentation or other forums has helped me understand the issue any better.

Here is the toy function I've been playing with, which recreates the behavior I've seen in the real function. I'm not familiar with how workers in Numba-parallelized loops return their results, but I don't see anything wrong with this toy function when I compare it to examples in the Numba documentation. I've also ensured that data types aren't an issue.
import numpy as np
from numba import jit, prange

@jit(nopython=True, parallel=False)
def test_loop(N, M, iters):
    """
    Multiply two vectors of shape (N,1) and (1,M), sum over axis 1,
    and add them to result for each loop
    """
    result = np.zeros((N, 1), dtype=np.float64)                   # Has shape (N, 1), float64
    for i in prange(iters):
        A = np.reshape(np.arange(0.0, N, 1.0), (-1, 1))           # Has shape (N, 1), float64
        B = np.reshape(np.ones((M, 1)), (1, -1))                  # Has shape (1, M), float64
        product = A * B                                           # Has shape (N, M), float64
        summation = np.reshape(np.sum(product, axis=1), (-1, 1))  # Has shape (N, 1), float64
        result += summation
    return result
I've resolved the issue. It appears that Numba doesn't like np.reshape() being given -1 for one of the shape arguments: the preallocated result array isn't initialized to the correct shape during compilation. When summation is added to result, the two arrays fail to broadcast because of the shape issue, so result remains unchanged. The code below works for both parallel=True and parallel=False; passing N and M explicitly as the appropriate arguments to np.reshape() resolves the issue. I don't know why this doesn't throw an error or warning when the loop is parallelized, though.
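For context, the silent no-op is surprising because in plain NumPy (outside a Numba-compiled function) a shape mismatch like this is loud, not silent. This is just an illustrative sketch of ordinary NumPy broadcasting, not a reproduction of the Numba-parallel behavior: adding a (3, 1) array to a (3,) array broadcasts to (3, 3), and the in-place += raises a ValueError because a (3, 3) result can't be written back into the (3, 1) buffer.

```python
import numpy as np

result = np.zeros((3, 1))
summation = np.ones(3)  # shape (3,) instead of the intended (3, 1)

# Out-of-place addition broadcasts both operands to (3, 3)
print((result + summation).shape)  # (3, 3)

# In-place addition cannot store a (3, 3) result in a (3, 1) buffer
try:
    result += summation
except ValueError as e:
    print("ValueError:", e)
```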
import numpy as np
from numba import njit, prange

@njit(parallel=True)
def test_loop(N, M, iters):
    """
    Multiply two vectors of shape (N,1) and (1,M), sum over axis 1,
    and add them to result for each loop
    """
    result = np.zeros((N, 1), dtype=np.float64)                  # Has shape (N, 1), float64
    for i in prange(iters):
        A = np.reshape(np.arange(0.0, N, 1.0), (N, 1))           # Has shape (N, 1), float64
        B = np.reshape(np.ones((M, 1)), (1, M))                  # Has shape (1, M), float64
        product = A * B                                          # Has shape (N, M), float64
        summation = np.reshape(np.sum(product, axis=1), (N, 1))  # Has shape (N, 1), float64
        result += summation
    return result