I have this piece of code that uses Numba to speed up processing. particle_dtype is defined so that the code can run under Numba. However, a TypingError is raised, saying "Cannot determine Numba type of <class 'function'>". I cannot figure out where the problem is.
import numpy
from numba import njit

particle_dtype = numpy.dtype({'names': ['x', 'y', 'z', 'm', 'phi'],
                              'formats': [numpy.double,
                                          numpy.double,
                                          numpy.double,
                                          numpy.double,
                                          numpy.double]})

def create_n_random_particles(n, m, domain=1):
    parts = numpy.zeros((n), dtype=particle_dtype)
    parts['x'] = numpy.random.random(size=n) * domain
    parts['y'] = numpy.random.random(size=n) * domain
    parts['z'] = numpy.random.random(size=n) * domain
    parts['m'] = m
    parts['phi'] = 0.0
    return parts

def distance(se, other):
    return numpy.sqrt(numpy.square(se['x'] - other['x']) +
                      numpy.square(se['y'] - other['y']) +
                      numpy.square(se['z'] - other['z']))

parts = create_n_random_particles(10, .001, 1)

@njit
def direct_sum(particles):
    for i, target in enumerate(particles):
        for j in range(particles.shape[0]):
            if i == j:
                continue
            source = particles[j]
            r = distance(target, source)
            # target['phi'] += source['m'] / r
            target['phi'] = target['phi'] + source['m'] / r
    return(target['phi'])

print(direct_sum(parts))
I guess it's because an unsupported function or operation is used somewhere, but I cannot find it. Thanks for your help.
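direct_sum, which is a JIT-compiled function, cannot call distance because distance is not JIT-compiled (it is a pure-Python function). The "<class 'function'>" in the error message refers to that plain Python function object, which Numba cannot type. You need to apply the @njit decorator to distance too.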
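For illustration, here is a minimal sketch of the fix with both functions decorated. A few things in it are my own assumptions rather than part of the question: the loops use plain integer indices instead of enumerate, direct_sum updates phi in place instead of returning a value, the two-particle test array is made up, and the try/except fallback only exists so the snippet still runs as plain Python when Numba is not installed.

```python
import numpy

try:
    from numba import njit
except ImportError:
    # Assumption for this sketch only: without Numba, use a no-op decorator
    # so the code still runs (slowly) as ordinary Python.
    def njit(func):
        return func

particle_dtype = numpy.dtype({'names': ['x', 'y', 'z', 'm', 'phi'],
                              'formats': [numpy.double] * 5})

@njit  # the fix: distance is now compiled, so direct_sum can call it
def distance(se, other):
    return numpy.sqrt(numpy.square(se['x'] - other['x']) +
                      numpy.square(se['y'] - other['y']) +
                      numpy.square(se['z'] - other['z']))

@njit
def direct_sum(particles):
    # Accumulates m / r from every other particle into each particle's phi.
    n = particles.shape[0]
    for i in range(n):
        target = particles[i]
        for j in range(n):
            if i == j:
                continue
            source = particles[j]
            r = distance(target, source)
            target['phi'] = target['phi'] + source['m'] / r

# Hypothetical test data: two particles of mass 2 that are 5 units apart.
two = numpy.zeros(2, dtype=particle_dtype)
two['x'] = [0.0, 3.0]
two['y'] = [0.0, 4.0]
two['m'] = 2.0

direct_sum(two)
print(two['phi'])
```

With these two particles, each one should end up with phi = m / r = 2 / 5 = 0.4, which is an easy way to check that the compiled version behaves like the pure-Python one.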