
Python, numba, class with field of its own type


I'm trying to use numba's jitclass on a class that has a field of its own type.

The code below does not work because Foo is not yet defined when its own annotation is needed.

from numba.experimental import jitclass

@jitclass
class Foo:
    a: int
    pred: 'Foo'

    def __init__(self, pred: 'Foo'):
        self.a = 1
        self.pred = pred


if __name__ == "__main__":
    x = Foo(None)

Replacing Foo with object also does not work. I also need to be able to pass None as pred in at least one case.

Is there a way to make this work?

The only other idea I have is to store pred in an external dictionary.


Solution

  • I think this should do what you want:

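    from numba.experimental import jitclass
    from numba import deferred_type, int64, optional
    
    # Placeholder for Foo's type; it is resolved after the class is compiled.
    foo_type = deferred_type()
    
    spec = dict()
    spec['a'] = int64
    # optional(...) lets pred hold either a Foo instance or None.
    spec['pred'] = optional(foo_type)
    
    
    @jitclass(spec)
    class Foo(object):
        def __init__(self, pred):
            self.a = 1
            self.pred = pred
    
    # Now that Foo exists, tell the placeholder which type it stands for.
    foo_type.define(Foo.class_type.instance_type)
    
    x = Foo(None)  # first node: no predecessor
    y = Foo(x)     # y.pred refers to x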
    

    It is based on this example.