Search code examples
pythonloopsrecursionnumbajit

Code to handle arbitrary number of for loops in Python/Numba


I have a function compiled under the 'njit' framework in Numba. It looks like this:

import numpy as np
from numba import njit, types, prange
from numba.typed import List

@njit(cache=CACHE_FLAG)
def find_combinations(target, *arrays):
    """
    Find all combinations of element indices from 7 arrays containing segment lengths that sum to the target update time.
    Returns both the lengths that sum up to the target and the corresponding indices.

    Args:
        target (int): The target sum.
        arrays (tuple of lists): Each list contains segment lengths.

    Returns:
        Tuple of two lists:
            1. Lengths list - contains tuples of segment lengths that sum up to the target.
            2. Indices list - contains tuples of indices in the original lists corresponding to the lengths.
    """
    lengths_list = List()
    indices_list = List()
    for i in prange(len(arrays[0])):
        sum_i = arrays[0][i]
        if sum_i > target:
            continue
        for j in range(len(arrays[1])):
            sum_j = sum_i + arrays[1][j]
            if sum_j > target:
                continue
            for k in range(len(arrays[2])):
                sum_k = sum_j + arrays[2][k]
                if sum_k > target:
                    continue
                for l in range(len(arrays[3])):
                    sum_l = sum_k + arrays[3][l]
                    if sum_l > target:
                        continue
                    for m in range(len(arrays[4])):
                        sum_m = sum_l + arrays[4][m]
                        if sum_m > target:
                            continue
                        for n in range(len(arrays[5])):
                            sum_n = sum_m + arrays[5][n]
                            if sum_n > target:
                                continue
                            for o in range(len(arrays[6])):
                                total = sum_n + arrays[6][o]
                                if total == target:
                                    lengths_list.append(
                                        (
                                            arrays[0][i],
                                            arrays[1][j],
                                            arrays[2][k],
                                            arrays[3][l],
                                            arrays[4][m],
                                            arrays[5][n],
                                            arrays[6][o],
                                        )
                                    )
                                    indices_list.append((i, j, k, l, m, n, o))

    return lengths_list, indices_list

The function works as expected but is designed to work for an 'arrays' input of length 7. The outputs lengths list and indices list contain some number of tuples of length 7 (= length of input). I want to re-write this function to be able to handle inputs of arbitrary lengths.

I tried to rewrite this function using recursion like this:

import numpy as np
from numba import njit, types, prange
from numba.typed import List

@njit
def find_combinations_recursive(arrays, target, current_sum, current_indices, current_lengths, lengths_list, indices_list, depth):
    # Base case: when we've processed all arrays
    if depth == len(arrays):
        if current_sum == target:
            # lengths_list.append(tuple(current_lengths)) # tuple does not work in the Pythonic way-- cannot convert iterable to tuple in Numba?
            # indices_list.append(tuple(current_indices))
            lengths_list.append(current_lengths)
            indices_list.append(current_indices)
        return
    
    # Recursive case: iterate over the current array
    for i in range(len(arrays[depth])):
        new_sum = current_sum + arrays[depth][i]
        if new_sum > target:
            continue
        current_indices.append(i)
        # current_lengths.extend(List([arrays[depth][i]]))
        current_lengths.append(arrays[depth][i])
        find_combinations_recursive(
            arrays, target, new_sum, current_indices, current_lengths, lengths_list, indices_list, depth + 1
        )

@njit(cache=True)
def find_combinations(target, arrays):
    # lengths_list = List()
    # indices_list = List()
    # lengths_list = List.empty_list(tuple([np.int64] * len(arrays)))
    # indices_list = List.empty_list(tuple([np.int64] * len(arrays)))
    lengths_list = List([List([1]*len(arrays))])
    # print(lengths_list)
    # indices_list = List(tuple([1] * len(arrays)))
    indices_list = List([List([1]*len(arrays))])
    # print(indices_list)
    # import pdb; pdb.set_trace()
    k = lengths_list.pop()
    x = indices_list.pop()
    current_indices = List.empty_list(types.int64)
    current_lengths = List.empty_list(types.int64)
    find_combinations_recursive(arrays, target, 0, current_indices, current_lengths, lengths_list, indices_list, 0)
    return lengths_list, indices_list

However, while the function runs it does not produce the expected output-- for some reason, lengths_list does not contain the expected elements and seems to contain a single list of ints rather than multiple lists of ints of length 7. Also, I had a lot of trouble with Numba's List() data structure and eventually resorted to a sort of hacky way of initializing where I give it an initial value to set the type and then pop this value out to obtain an empty list of desired type. What do I need to fix to get this all to work coherently?


Solution

  • Your recursive implementation was almost complete, except for one bug.

    Here is the problem:

        # Recursive case: iterate over the current array
        for i in range(len(arrays[depth])):
            new_sum = current_sum + arrays[depth][i]
            if new_sum > target:
                continue
            current_indices.append(i)  # <-- HERE!
    

    Since you are adding to current_indices in-place, it keeps growing longer with each loop iteration. But, this list should grow longer with each recursive call, not with each loop iteration. This was your only mistake.

    Regarding how to create a list of lists, although there isn't any documentation that I know of, it's quite simple. You can do it like this:

    indices_list = List.empty_list(List.empty_list(types.int64))
    

    Or you can even use another typed.List like this:

    current_indices = List.empty_list(types.int64)
    indices_list = List.empty_list(current_indices)
    

    Here is the code with the problem fixed. Although it wasn't part of your request, I modified the code to return only the indices. Since you can easily retrieve the values from arrays using the indices, separating these two concerns makes the code simpler.

    import numpy as np
    from numba import njit, types
    from numba.typed import List
    
    
    @njit
    def find_combinations_recursive(arrays, target, current_sum, current_indices, indices_list, depth):
        if depth == len(arrays):
            if current_sum == target:
                indices_list.append(current_indices)
            return
    
        for i in range(len(arrays[depth])):
            new_sum = current_sum + arrays[depth][i]
            if new_sum > target:
                continue
    
            # This is an in-place operation, and it will grow longer for each iteration.
            # current_indices.append(i)
    
            # This is a copy operation, and a new list is created for each iteration.
            new_indices = current_indices.copy()
            new_indices.append(i)
    
            find_combinations_recursive(arrays, target, new_sum, new_indices, indices_list, depth + 1)
    
    
    @njit(cache=True)
    def find_combinations(target, arrays):
        current_indices = List.empty_list(types.int64)
    
        # To create a list of lists, like this.
        indices_list = List.empty_list(List.empty_list(types.int64))
    
        # Or you can do it like this.
        # indices_list = List.empty_list(current_indices)
    
        find_combinations_recursive(arrays, target, 0, current_indices, indices_list, 0)
        return indices_list
    
    
    @njit
    def get_lengths_list(arrays, indices_list):
        """Retrieve values from arrays using lists of indices."""
        lengths_list = []
        for indices in indices_list:
            lengths_list.append([arrays[i][indices[i]] for i in range(len(arrays))])
        return lengths_list
    
    
    def main():
        np.random.seed(0)
        arrays = np.random.randint(0, 100, (7, 7))
        target = arrays.max(axis=1).sum()
        print("arrays:")
        print(arrays)
        print(f"{target=}")
        indices_list = find_combinations(target, arrays)
        lengths_list = get_lengths_list(arrays, indices_list)
        print(f"{indices_list=}")
        print(f"{lengths_list=}")
    
    
    if __name__ == '__main__':
        main()
    

    Result:

    arrays:
    [[44 47 64 67 67  9 83]
     [21 36 87 70 88 88 12]
     [58 65 39 87 46 88 81]
     [37 25 77 72  9 20 80]
     [69 79 47 64 82 99 88]
     [49 29 19 19 14 39 32]
     [65  9 57 32 31 74 23]]
    target=np.int64(561)
    indices_list=ListType[ListType[int64]]([[6, 4, 5, 6, 5, 0, 5], [6, 5, 5, 6, 5, 0, 5]])
    lengths_list=[[83, 88, 88, 80, 99, 49, 74], [83, 88, 88, 80, 99, 49, 74]]