
TypingError for Numba


I have this piece of code that uses Numba to speed up processing. Basically, particle_dtype is defined so that the code can run under Numba. However, Numba reports a TypingError saying "Cannot determine Numba type of <class 'function'>", and I cannot figure out where the problem is.

import numpy
from numba import njit

particle_dtype = numpy.dtype({'names':['x','y','z','m','phi'], 
                             'formats':[numpy.double, 
                                        numpy.double, 
                                        numpy.double, 
                                        numpy.double, 
                                        numpy.double]}) 


def create_n_random_particles(n, m, domain=1):
    parts = numpy.zeros((n), dtype=particle_dtype)
    parts['x'] = numpy.random.random(size=n) * domain
    parts['y'] = numpy.random.random(size=n) * domain
    parts['z'] = numpy.random.random(size=n) * domain
    parts['m'] = m
    parts['phi'] = 0.0

    return parts


def distance(se, other):
    return numpy.sqrt(numpy.square(se['x'] - other['x']) + 
                      numpy.square(se['y'] - other['y']) + 
                      numpy.square(se['z'] - other['z']))


parts = create_n_random_particles(10, .001, 1)


@njit
def direct_sum(particles):
    for i, target in enumerate(particles):
        for j in range(particles.shape[0]):
            if i == j:
                continue
            source = particles[j]
            r = distance(target, source)
            # target['phi'] += source['m'] / r
            target['phi'] = target['phi'] + source['m'] / r
            return(target['phi'])
            
print(direct_sum(parts))

I suspect an unsupported function or operation is being used somewhere, but I cannot find it. Thanks for your help.


Solution

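  • direct_sum is a JIT-compiled function (@njit, i.e. nopython mode), so it cannot call distance, which is an ordinary pure-Python function. In nopython mode Numba has to assign a type to every global name the function uses, and a plain Python function has no Numba type, which is exactly what "Cannot determine Numba type of <class 'function'>" is reporting. Decorate distance with @njit as well.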