Search code examples
python, numpy, numba

How to create a list of NumPy arrays in a jitclass


I want to create a jitclass that stores some NumPy arrays, but I don't know in advance how many there will be, so I want to keep them in a list of NumPy arrays. I find the Numba types confusing, but I found a somewhat strange solution. The following runs normally.

import numba
from numba import types, typed, typeof
from numba.experimental import jitclass
import numpy as np


# Field layout for the jitclass: one field, `test`, whose type is taken from an
# empty typed list of 1-D int64 arrays built here, outside of any jitted code.
spec = [
    ('test', typeof(typed.List.empty_list(numba.int64[:])))
]

@jitclass(spec)
class myLIST(object):
    # Working variant: the empty typed list comes in through the default value
    # of `haha`, which Python evaluates when the `def` statement runs, i.e.
    # before the class is handed to `jitclass`.
    # NOTE(review): as with any mutable default, instances created without an
    # argument presumably share this one list -- confirm before relying on it.
    def __init__ (self, haha=typed.List.empty_list(numba.int64[:])):
        self.test = haha
        # Seed the list with a single one-element array.
        self.test.append(np.asarray([0]))

    def dump(self):
        # Append another one-element array, then print the whole list.
        self.test.append(np.asarray([1]))
        print(self.test)

# Driver: build an instance using the default list and show its contents.
a = myLIST()
a.dump()

But when I remove the redundant variable, it fails.

# Same field layout as in the working version: one field, `test`, typed from an
# empty typed list of 1-D int64 arrays.
spec = [
    ('test', typeof(typed.List.empty_list(numba.int64[:])))
]

@jitclass(spec)
class myLIST(object):
    def __init__ (self):
        # Failing variant: the empty typed list is now created inside the
        # constructor itself, with the element type written inline as
        # `numba.int64[:]`. This is the version the question reports as failing.
        self.test = typed.List.empty_list(numba.int64[:])
        self.test.append(np.asarray([0]))

    def dump(self):
        # Append another one-element array, then print the whole list.
        self.test.append(np.asarray([1]))
        print(self.test)

# Driver: per the question, this variant does not run successfully.
a = myLIST()
a.dump()

Why does this happen?


Solution

  • It seems that declaring the array type as numba.int64[:] (written nb.int64[:] below) doesn't give Numba enough information to create the class, unless you also create an instance (the default value for haha) that Numba can use to infer the type.

    Instead, you can declare the array type explicitly (the snippets below assume import numba as nb and import numpy as np):

    # Spell the element type out in full: a 1-D int64 array with C layout.
    # (`nb` is assumed to be `import numba as nb`.)
    int_vector = nb.types.Array(dtype=nb.int64, ndim=1, layout="C")
    # Derive the field type from an empty typed list of that element type.
    spec = [('test', nb.typeof(nb.typed.List.empty_list(int_vector)))]
    

    Or shorter:

    # Element type: a 1-D int64 array with C layout.
    # (`nb` is assumed to be `import numba as nb`.)
    int_vector = nb.types.Array(dtype=nb.int64, ndim=1, layout="C")
    # Name the list type directly instead of building a throwaway list and
    # calling typeof() on it.
    spec = [('test', nb.types.ListType(int_vector))]
    

    Or, if you can use type annotations:

    # Element type: a 1-D int64 array with C layout, defined once outside the
    # class so both the field annotation and __init__ can refer to it.
    # (`nb` / `np` are assumed to be `import numba as nb` / `import numpy as np`.)
    int_vector = nb.types.Array(dtype=nb.int64, ndim=1, layout="C")
    
    @nb.experimental.jitclass
    class my_list:
    
        # No spec list is passed to jitclass here; the field type is declared
        # through this class-level annotation instead.
        test: nb.types.ListType(int_vector)
    
        def __init__(self):
            # Create the empty typed list inside the constructor, using the
            # predefined `int_vector` element type, then seed it with one array.
            self.test = nb.typed.List.empty_list(int_vector)
            self.test.append(np.array([0]))
    
        def dump(self):
            # Append another one-element array, then print the whole list.
            self.test.append(np.array([1]))
            print(self.test)
    
    # Driver: construct with no arguments and show the list contents.
    a = my_list()
    a.dump()