Numba TypingError on higher-dimensional structured NumPy datatypes


The following code compiles and executes correctly:

import numpy as np
from numba import njit

Particle = np.dtype([ ('position', 'f4'), ('velocity', 'f4')])

arr = np.zeros(2, dtype=Particle)

@njit
def f(x):
    x[0]['position'] = x[1]['position'] + x[1]['velocity'] * 0.2 + 1.
    
f(arr)

However, when each field is made a subarray (shape (2,)), the code fails to compile, even though it still works without @njit:

import numpy as np
from numba import njit

Particle = np.dtype([
            ('position', 'f4', (2,)),
            ('velocity', 'f4', (2,))
          ])

arr = np.zeros(2, dtype=Particle)

@njit
def f(x):
    x[0]['position'] = x[1]['position'] + x[1]['velocity'] * 0.2 + 1.
    
f(arr)

It fails with the following error:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function setitem>) found for signature:
 
 >>> setitem(Record(position[type=nestedarray(float32, (2,));offset=0],velocity[type=nestedarray(float32, (2,));offset=8];16;False), Literal[str](position), array(float64, 1d, C))
 
There are 16 candidate implementations:
    - Of which 16 did not match due to:
    Overload of function 'setitem': File: <numerous>: Line N/A.
      With argument(s): '(Record(position[type=nestedarray(float32, (2,));offset=0],velocity[type=nestedarray(float32, (2,));offset=8];16;False), unicode_type, array(float64, 1d, C))':
     No match.

During: typing of staticsetitem at /tmp/ipykernel_21235/2952285515.py (13)

File "../../../../tmp/ipykernel_21235/2952285515.py", line 13:
<source missing, REPL/exec in use?>
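The signature in the error is setitem on a Record with an array value, so the step that fails is the assignment to x[0]['position']. The arithmetic on the right-hand side is not the problem, because Numba has already typed it as array(float64, 1d, C). The sketch below isolates that split, assuming you are willing to do the store outside the jitted function. The helper next_position is hypothetical: it only reads the nested fields and returns a new array, and plain NumPy then does the assignment. The try/except fallback, which swaps njit for a no-op decorator when Numba is not installed, is also an assumption. It is there only so the snippet runs anywhere.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: no-op fallback so the sketch runs without Numba
    def njit(func):
        return func

Particle = np.dtype([("position", "f4", (2,)), ("velocity", "f4", (2,))])
arr = np.zeros(2, dtype=Particle)


@njit
def next_position(x):
    # Reading nested-array fields inside njit is fine; this returns a NEW array
    return x[1]["position"] + x[1]["velocity"] * 0.2 + 1.0


# The record-field store happens in plain NumPy, outside the jitted function
arr[0]["position"] = next_position(arr)
print(arr)
```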

Any thoughts on how to fix the latter? I would like to use higher-dimensional datatypes.


Solution

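  • Numba has no implementation for assigning a whole array to a record field (the setitem(Record, ..., array) in the error), but x[0]["position"] returns an array view into the record's memory. Assigning into that view with [:] writes the values in place: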

    import numpy as np
    from numba import njit
    
    Particle = np.dtype([("position", "f4", (2,)), ("velocity", "f4", (2,))])
    
    arr = np.zeros(2, dtype=Particle)
    
    
    @njit
    def f(x):
        pos_0 = x[0]["position"]
        pos_0[:] = x[1]["position"] + x[1]["velocity"] * 0.2 + 1.0
    
        #x[0]["position"][:] = ... works too
    
    f(arr)
    print(arr)
    

    Prints:

    [([1., 1.], [0., 0.]) ([0., 0.], [0., 0.])]
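Another option is to skip records inside the jitted code entirely. Indexing the whole structured array by field name, as in arr["position"], gives a strided 2-D float32 view with shape (n_particles, 2), and Numba accepts non-contiguous arrays like that as ordinary arguments. The sketch below assumes a hypothetical step function that takes the field views and a time step dt. The no-op njit fallback is again an assumption, included only so the snippet runs without Numba.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: no-op fallback so the sketch runs without Numba
    def njit(func):
        return func

Particle = np.dtype([("position", "f4", (2,)), ("velocity", "f4", (2,))])
arr = np.zeros(2, dtype=Particle)


@njit
def step(pos, vel, dt):
    # pos and vel are plain 2-D float32 views with shape (n_particles, 2)
    pos[0, :] = pos[1, :] + vel[1, :] * dt + 1.0


# Field views share memory with arr, so writes through pos land in arr
step(arr["position"], arr["velocity"], 0.2)
print(arr)
```

With this design, the jitted function never sees a Record type, and updating every particle becomes a plain loop over the first axis of pos and vel.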