
Concatenate python tuples in numba


I'm looking to fill up an array of zeros with numbers taken from some tuples, easy as that.

Usually this is not a problem, even when the tuples have different lengths (which is the point here). But once I wrap the function with numba's jit it won't compile for uneven tuples, and I cannot figure out a solution.

from numba import jit
import numpy as np

def cant_jit(ls):

    # Array total length
    tl = 6
    # Type
    typ = np.int64

    # Array to modify and return
    start = np.zeros((len(ls), tl), dtype=typ)

    for i in range(len(ls)):

        # Pad the i-th tuple with zeros up to length tl
        a = np.array(ls[i], dtype=typ)
        z = np.zeros((tl - len(ls[i]),), dtype=typ)
        c = np.concatenate((a, z))
        start[i] = c

    return start

# Uneven tuples would be no problem in vanilla Python
cant_jit(((2, 4), (6, 8, 4)))


jt = jit(cant_jit)
# Works fine: both tuples have length 2
jt(((2, 4), (6, 8)))
# Doesn't work: the tuples have lengths 2 and 3
jt(((2, 4), (6, 8, 4)))

The relevant part of the error:

getitem(Tuple(UniTuple(int64 x 3), UniTuple(int64 x 2)), int64)
There are 22 candidate implementations:
 - Of which 22 did not match due to:
 Overload of function 'getitem': File: : Line N/A.
 With argument(s): '(Tuple(UniTuple(int64 x 3), UniTuple(int64 x 2)), int64)':
 No match.

I tried a few things without success. Does anyone know a way around this, so that the function compiles and still handles tuples of different lengths?


Solution

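  • As far as I can tell this isn't possible in nopython mode. The numba documentation only supports iteration and indexing over homogeneous tuples. A tuple of tuples with different lengths is heterogeneous, because each inner tuple has a different type, so ls[i] with a loop index cannot be typed unless you use forceobj=True. You can't even unpack *args, which is frustrating. Without forceobj=True you will always receive that warning/error.

    Just add that argument to jit() like this: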

    from numba import jit
    import numpy as np
    
    def cant_jit(ls):
    
        # Array total length
        tl = 6
        # Type
        typ = np.int64
    
        # Array to modify and return
        start = np.zeros((len(ls), tl), dtype=typ)
    
        for i in range(len(ls)):
    
            # Pad the i-th tuple with zeros up to length tl
            a = np.array(ls[i], dtype=typ)
            z = np.zeros((tl - len(ls[i]),), dtype=typ)
            c = np.concatenate((a, z))
            start[i] = c
    
        return start
    
    # Uneven tuples would be no problem in vanilla Python
    cant_jit(((2, 4), (6, 8, 4)))
    
    
    # forceobj=True compiles in object mode, which accepts the uneven tuples
    jt = jit(cant_jit, forceobj=True)
    # Works fine
    jt(((2, 4), (6, 8)))
    # Now works too
    jt(((2, 4), (6, 8, 4)))

    This works, but it's kind of pointless: object mode operates on ordinary Python objects, so there is little speedup to gain, and you may as well use core Python.
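Following that suggestion, a minimal plain Python/NumPy sketch without numba. pad_rows is a hypothetical name, and its defaults mirror tl = 6 and typ = np.int64 from the question. It writes each tuple directly into the start of its row. This avoids building and concatenating a separate zeros array, because np.zeros already provides the padding.

```python
import numpy as np

def pad_rows(ls, tl=6, typ=np.int64):
    """Hypothetical plain-Python version of cant_jit: one zero-padded row per tuple."""
    start = np.zeros((len(ls), tl), dtype=typ)
    for i, row in enumerate(ls):
        # Tuples may differ in length; the remaining entries stay zero
        start[i, :len(row)] = row
    return start

print(pad_rows(((2, 4), (6, 8, 4))))
# [[2 4 0 0 0 0]
#  [6 8 4 0 0 0]]
```

As in the original function, a tuple longer than tl will not fit in its row. Here that raises a ValueError from NumPy when the assignment is attempted.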