I am trying to speed up some code using numba, but it is tough sledding. For example, the following function does not numba-fy:
@jit(nopython=True)
def returns(Ft, x, delta):
    T = len(x)
    rets = Ft[0:T - 1] * x[1:T] - delta * np.abs(Ft[1:T] - Ft[0:T - 1])
    return np.concatenate([[0], rets])
because numba cannot find the signature of np.concatenate. Is there a canonical fix for this?
A bit late, but I hope still useful. Since you asked for the "canonical fix", I would like to explain why concatenate is a bad idea when working with arrays, especially since you want to remove bottlenecks and are using the numba jit for that. An array is a contiguous sequence of bytes in memory (numpy knows some tricks to change the order without copying by creating views, but that is another topic, see https://towardsdatascience.com/advanced-numpy-master-stride-tricks-with-25-illustrated-exercises-923a9393ab20). If you want to prepend the value x to an array of N elements, you need to create a new array with N+1 elements, set the first value to x, and copy the remaining N elements into it. As a side note, a similar argument holds for prepending items to a Python list, which is the reason why collections.deque exists.
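To make the copying cost concrete, here is a small sketch in plain NumPy (no numba involved). The array `a` and all variable names are made up for illustration only.

```python
from collections import deque

import numpy as np

a = np.arange(1.0, 6.0)  # illustrative data with 5 float elements

# Prepending via concatenate allocates a new array and copies all of `a`.
b = np.concatenate((np.zeros(1), a))
copied = not np.shares_memory(a, b)  # b does not reuse a's buffer

# Doing the same by hand makes the allocation and the copy explicit.
c = np.empty(a.size + 1, dtype=a.dtype)
c[0] = 0.0
c[1:] = a

# A deque is designed for cheap insertion at the left end.
d = deque([1, 2, 3])
d.appendleft(0)
```

Both `b` and `c` hold the same N+1 values; the hand-written version just shows where the memory comes from.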
Now, in your jit-decorated function, you could hope that the compiler understands what you want to do, but writing a compiler that always understands what you are trying to do is nearly impossible. It is therefore better to be kind to the compiler and help out with the memory layout whenever you know the right choice. Thus, IMHO the "canonical fix" for your example code would be something like the following:
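import numpy as np
from numba import jit

@jit(nopython=True)
def returns(Ft, x, delta):
    T = len(x)
    # Allocate the output once; it inherits the shape and dtype of x.
    rets = np.empty_like(x)
    rets[0] = 0
    # Fill the remaining T - 1 entries in place, no concatenation needed.
    rets[1:T] = Ft[0:T - 1] * x[1:T] - delta * np.abs(Ft[1:T] - Ft[0:T - 1])
    return rets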
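As a quick sanity check, the preallocated version can be compared against the original concatenate-based formula. The sketch below deliberately leaves out the jit decorator so that it runs even where numba is not installed; the function names `returns_prealloc` and `returns_concat` and the input values are made up for illustration.

```python
import numpy as np

def returns_prealloc(Ft, x, delta):
    # Same body as the preallocated version, without the jit decorator.
    T = len(x)
    rets = np.empty_like(x)
    rets[0] = 0
    rets[1:T] = Ft[0:T - 1] * x[1:T] - delta * np.abs(Ft[1:T] - Ft[0:T - 1])
    return rets

def returns_concat(Ft, x, delta):
    # The formulation from the question, which works in plain NumPy.
    T = len(x)
    rets = Ft[0:T - 1] * x[1:T] - delta * np.abs(Ft[1:T] - Ft[0:T - 1])
    return np.concatenate([[0], rets])

# Made-up float inputs of equal length.
Ft = np.array([0.5, -0.5, 1.0, 0.0])
x = np.array([0.1, 0.2, -0.3, 0.4])
delta = 0.01

out_prealloc = returns_prealloc(Ft, x, delta)
out_concat = returns_concat(Ft, x, delta)
same = np.allclose(out_prealloc, out_concat)
```

One caveat: `np.empty_like(x)` inherits the dtype of `x`, so `x` should be a floating point array. With an integer `x`, the results would be truncated when written into `rets`.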
In general, I agree with @Aaron's comment: you should always be as explicit as possible about the input types of any function you call within jit-decorated functions. In your case, ask yourself as a compiler "what is [[0], rets]?". Thinking in strict types, you see a list containing a list of one integer and an array of floating point (or complex) numbers. That is a challenging mixture of types for a compiler. Should the output become an array of integers or floats?
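For comparison, here is how plain NumPy answers that question, together with a type-explicit alternative. This is only a sketch with made-up values. It runs in plain NumPy, and whether the explicit form compiles under your numba version is an assumption you should verify (as far as I know, numba's nopython mode expects a tuple of arrays here rather than a nested list).

```python
import numpy as np

rets = np.array([0.09, 0.135, 0.39])  # made-up float values

# Plain NumPy resolves the mixture by promoting the integer to float.
mixed = np.concatenate([[0], rets])
mixed_dtype = mixed.dtype

# Type-explicit alternative: a tuple of arrays that share one dtype,
# so nothing has to be guessed about the output type.
explicit = np.concatenate((np.zeros(1, dtype=rets.dtype), rets))
```

In the explicit form every input is already an array of a known dtype, which is exactly the kind of help the compiler needs.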