
Indexing multidimensional numpy array inside numba's jitclass


I'm trying to insert a small multidimensional array into a larger one inside a numba jitclass. The small array is written to specific positions of the larger array, defined by an index list.

The following MWE shows the problem without numba; everything works as expected:

import numpy as np

class NumbaClass(object):

    def __init__(self, n, m):
        self.A = np.zeros((n, m))

    # solution 1 using pure python
    def nonNumbaFunction1(self, idx, values):
        self.A[idx[:, None], idx] = values

    # solution 2 using pure python
    def nonNumbaFunction2(self, idx, values):
        self.A[np.ix_(idx, idx)] = values

if __name__ == "__main__":
    n = 6
    m = 8
    obj = NumbaClass(n, m)
    print(f'A =\n{obj.A}')

    idx = np.array([0, 2, 5])
    values = np.arange(len(idx)**2).reshape(len(idx), len(idx))
    print(f'values =\n{values}')

    obj.nonNumbaFunction1(idx, values)
    print(f'A =\n{obj.A}')

    obj.nonNumbaFunction2(idx, values)
    print(f'A =\n{obj.A}')
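
Both pure-NumPy variants rely on the same mechanism: `idx[:, None]` has shape (k, 1) and `idx` has shape (k,), so they broadcast to a (k, k) grid of row and column indices, which is what `np.ix_(idx, idx)` builds explicitly. A minimal check, reusing the question's `n`, `m`, `idx` and `values` (the names `A1` and `A2` are just for illustration):

```python
import numpy as np

n, m = 6, 8
idx = np.array([0, 2, 5])
values = np.arange(len(idx) ** 2).reshape(len(idx), len(idx))

# np.ix_ returns an open mesh: row indices with shape (3, 1), column indices with shape (1, 3)
rows, cols = np.ix_(idx, idx)
print(rows.shape, cols.shape)  # (3, 1) (1, 3)

# idx[:, None] is the same (3, 1) row index; plain idx broadcasts like the (1, 3) columns
assert np.array_equal(rows, idx[:, None])
assert np.array_equal(cols.ravel(), idx)

A1 = np.zeros((n, m))
A1[idx[:, None], idx] = values

A2 = np.zeros((n, m))
A2[np.ix_(idx, idx)] = values

# both assignments write values[i, j] into A[idx[i], idx[j]]
assert np.array_equal(A1, A2)
assert A1[2, 5] == values[1, 2]
```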

Neither nonNumbaFunction1 nor nonNumbaFunction2 works inside a numba jitclass, so my current solution looks like this, which is not really nice in my opinion:

import numpy as np

from numba import float64
from numba.experimental import jitclass  # numba >= 0.49; older versions: from numba import jitclass
from collections import OrderedDict

specs = OrderedDict()
specs['A'] = float64[:, :]

@jitclass(specs)
class NumbaClass(object):

    def __init__(self, n, m):
        self.A = np.zeros((n, m))

    # solution for numba jitclass
    def numbaFunction(self, idx, values):
        for i in range(len(values)):
            idxi = idx[i]
            for j in range(len(values)):
                idxj = idx[j]
                self.A[idxi, idxj] = values[i, j]

if __name__ == "__main__":
    n = 6
    m = 8
    obj = NumbaClass(n, m)
    print(f'A =\n{obj.A}')

    idx = np.array([0, 2, 5])
    values = np.arange(len(idx)**2).reshape(len(idx), len(idx))
    print(f'values =\n{values}')

    obj.numbaFunction(idx, values)
    print(f'A =\n{obj.A}')

So my questions are:

  • Does anyone know a solution to this indexing in numba or is there another vectorized solution?
  • Is there a faster solution for nonNumbaFunction1?

It might be useful to know that the inserted array is small (4x4 to 10x10), but this indexing appears in nested loops, so it has to be quite fast as well. Later I will need similar indexing for three-dimensional arrays too.
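
For the three-dimensional case, plain NumPy extends naturally because `np.ix_` accepts one index array per axis. The sketch below compares that with an explicit triple loop, which is the form that can be compiled in numba's nopython mode; the helper name `insert_3d_loops` is hypothetical.

```python
import numpy as np


def insert_3d_loops(target, values, idx):
    # explicit loops: values[i, j, k] goes to target[idx[i], idx[j], idx[k]]
    for i in range(values.shape[0]):
        for j in range(values.shape[1]):
            for k in range(values.shape[2]):
                target[idx[i], idx[j], idx[k]] = values[i, j, k]


n = 6
idx = np.array([0, 2, 5])
size = len(idx)
values = np.arange(size ** 3, dtype=np.float64).reshape(size, size, size)

A_fancy = np.zeros((n, n, n))
A_fancy[np.ix_(idx, idx, idx)] = values  # vectorised, pure NumPy

A_loop = np.zeros((n, n, n))
insert_3d_loops(A_loop, values, idx)

assert np.array_equal(A_fancy, A_loop)
```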


Solution

  • Because of limitations in numba's indexing support, I don't think you can do much better than writing out the for loops yourself. To make it generic across dimensions, you could use the generated_jit decorator to specialize on the number of dimensions. Something like this:

    import numba

    def set_2d(target, values, idx):
        # values[i, j] goes to target[idx[i], idx[j]]
        for i in range(values.shape[0]):
            for j in range(values.shape[1]):
                target[idx[i], idx[j]] = values[i, j]

    def set_3d(target, values, idx):
        for i in range(values.shape[0]):
            for j in range(values.shape[1]):
                for k in range(values.shape[2]):
                    target[idx[i], idx[j], idx[k]] = values[i, j, k]

    @numba.generated_jit
    def set_nd(target, values, idx):
        # called with numba *types* at compile time; returns the
        # implementation to compile for this number of dimensions
        if target.ndim == 2:
            return set_2d
        elif target.ndim == 3:
            return set_3d
    

    This can then be used in your jitclass:

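    import numpy as np
    from collections import OrderedDict
    from numba import float64
    from numba.experimental import jitclass  # numba >= 0.49; older versions: from numba import jitclass

    specs = OrderedDict()
    specs['A'] = float64[:, :]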
    
    @jitclass(specs)
    class NumbaClass(object):
        def __init__(self, n, m):
            self.A = np.zeros((n, m))
        def numbaFunction(self, idx, values):
            set_nd(self.A, values, idx)
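
In newer Numba releases `numba.generated_jit` is deprecated, and the deprecation notice points to `numba.extending.overload` instead. Below is a sketch of the same ndim-based dispatch written with `overload`, offered as an illustration under that assumption rather than a drop-in replacement; the names `_set_2d`, `_set_3d`, `insert` and `HAVE_NUMBA` are hypothetical. If Numba is not installed, it falls back to plain Python so the logic can still be checked.

```python
import numpy as np

try:
    from numba import njit, types
    from numba.extending import overload
    HAVE_NUMBA = True
except ImportError:  # let the sketch run (slowly) without numba
    HAVE_NUMBA = False


def _set_2d(target, values, idx):
    for i in range(values.shape[0]):
        for j in range(values.shape[1]):
            target[idx[i], idx[j]] = values[i, j]


def _set_3d(target, values, idx):
    for i in range(values.shape[0]):
        for j in range(values.shape[1]):
            for k in range(values.shape[2]):
                target[idx[i], idx[j], idx[k]] = values[i, j, k]


def set_nd(target, values, idx):
    # pure-Python path: dispatch on the runtime number of dimensions
    if target.ndim == 2:
        _set_2d(target, values, idx)
    elif target.ndim == 3:
        _set_3d(target, values, idx)
    else:
        raise ValueError("only 2-D and 3-D targets are supported")


if HAVE_NUMBA:
    @overload(set_nd)
    def _set_nd_overload(target, values, idx):
        # in compiled code this runs once per type signature and picks
        # an implementation from the numba array type's ndim
        if isinstance(target, types.Array):
            if target.ndim == 2:
                return _set_2d
            if target.ndim == 3:
                return _set_3d
        return None  # unsupported types -> typing error

    @njit
    def insert(target, values, idx):
        set_nd(target, values, idx)
else:
    insert = set_nd
```

Since jitclass methods are also compiled in nopython mode, calling `set_nd(self.A, values, idx)` from `numbaFunction` should resolve through the same overload.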