Tags: python, list, numpy, parallel-processing, numba

How to use Numba in parallel with varying-length lists


I have a function that is already njitted and returns a list of uint32 values of varying length.

@numba.njit()
def inner_function(A: numba.float32[:]) -> List[int]:
    # ...

That function is called in a loop a large number of times. The loop could be parallelized, since there are no data dependencies between iterations.

However, because the returned lists vary in length, I can't write the results into a preallocated array.

What I currently have is a sequential approach using Numba's typed List, like this:

@numba.njit()
def looping_function(A: numba.float32[:,:]) -> List[List[int]]:
    result = List()
    for i in range(A.shape[0]):
        tmp = inner_function(A[i])
        result.append(tmp)
    return result

Is there a way to initialize the result list (or some other data structure) so that I can index into it instead of appending to it, like this:

@numba.njit(parallel=True)
def looping_function(A: numba.float32[:,:]):
    result = ???
    for i in numba.prange(A.shape[0]):
        tmp = inner_function(A[i])
        result[i] = tmp
    return result

Solution

  • I found a solution. It's not very satisfying, but it works well enough for my use case. It preallocates a list of lists, each holding a placeholder 0 so that Numba can infer the element type. Each parallel iteration then removes its placeholder and extends its own inner list with the real result.

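    import numba
    from numba.typed import List

    @numba.njit(parallel=True)
    def outer(n):
        # Preallocate n inner lists, each holding a placeholder 0 so that
        # Numba can infer the element type (int64) when compiling.
        a = List([List([0]) for _ in range(n)])
        for i in numba.prange(n):
            a[i].pop()             # remove the placeholder
            a[i].extend(inner(i))  # `inner` is the njitted per-item function
        return a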
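If the per-row work can be split into a cheap "how many results?" step and a "write the results" step, a preallocated array can still be used despite the varying lengths. The sketch below uses a two-pass, CSR-style layout: pass 1 counts each row's results in parallel, a prefix sum turns the counts into offsets, and pass 2 lets each row write into its own disjoint slice of one flat `uint32` array. This is a swapped-in technique, not the approach from the answer above. `inner_count`, `inner_fill`, and the "indices of positive entries" rule are hypothetical stand-ins for the question's `inner_function`. The small shim at the top is also an assumption: it only lets the sketch run (without any speedup) when Numba is not installed.

```python
import numpy as np

try:
    import numba
    njit, prange = numba.njit, numba.prange
except ImportError:  # fallback shim: plain Python, no JIT and no parallelism
    def njit(*args, **kwargs):
        if args and callable(args[0]):
            return args[0]
        return lambda f: f
    prange = range


@njit
def inner_count(row):
    # Hypothetical stand-in: number of strictly positive entries in the row.
    n = 0
    for x in row:
        if x > 0:
            n += 1
    return n


@njit
def inner_fill(row, out):
    # Hypothetical stand-in: write indices of positive entries into `out`,
    # which must have exactly inner_count(row) elements.
    k = 0
    for j in range(row.shape[0]):
        if row[j] > 0:
            out[k] = j
            k += 1


@njit(parallel=True)
def ragged_parallel(A):
    n = A.shape[0]
    counts = np.empty(n, dtype=np.int64)
    for i in prange(n):  # pass 1: result sizes only
        counts[i] = inner_count(A[i])
    offsets = np.zeros(n + 1, dtype=np.int64)
    offsets[1:] = np.cumsum(counts)  # row i lives in data[offsets[i]:offsets[i + 1]]
    data = np.empty(offsets[-1], dtype=np.uint32)
    for i in prange(n):  # pass 2: each row fills its own disjoint slice
        inner_fill(A[i], data[offsets[i]:offsets[i + 1]])
    return data, offsets
```

Row `i` is recovered as `data[offsets[i]:offsets[i + 1]]`. Because the slices never overlap, the parallel writes need no locking. The trade-off is that each row is visited twice, so this only pays off when counting is much cheaper than producing the results.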