Search code examples
pythonnumba

Numba - cannot infer type for List()


I am trying to use numba to speed up on of the fuzzy search function of the python package. My plan is to first use njit sequentially and then move to parallelism if the goal is not met. So I am converting the original function in library to numba supported types. I am using typed List instead of the normal python List. The numba throws the error "Cannot infer the type of variable 'candidates', have imprecise type: ListType[undefined]". I am confused as to why this error occurs? Isnt this the way to declare a typed list variable?

I am new to numba so any suggestions regarding alternate efficient ways to speed this up are welcome.

@njit
def make_char2first_subseq_index(subsequence, max_l_dist):
    d = Dict.empty(
        key_type=types.unicode_type,
        value_type=numba.int64,
    )
    for (index, char) in list(enumerate(subsequence[:max_l_dist + 1])):
        d[char] = index
    return d


@njit
def find_near_matches_levenshtein_linear_programming(subsequence, sequence,
                                                     max_l_dist):
    if not subsequence:
        raise ValueError('Given subsequence is empty!')

    subseq_len = len(subsequence)

    def make_match(start, end, dist):
        # return Match(start, end, dist, matched=sequence[start:end])
        return str(start) + " " + str(end) + " " + str(dist) + " " + str(sequence[start:end])

    if max_l_dist >= subseq_len:
        for index in range(len(sequence) + 1):
            return make_match(index, index, subseq_len)

    # optimization: prepare some often used things in advance
    char2first_subseq_index = make_char2first_subseq_index(subsequence,
                                                           max_l_dist)

    candidates = List()
    for index, char in enumerate(sequence):
        # print('/n new loop and the character is ', char)
        new_candidates = List()

        idx_in_subseq = char2first_subseq_index.get(char, None)
        # print("idx_in_subseq ", idx_in_subseq)
        if idx_in_subseq is not None:
            if idx_in_subseq + 1 == subseq_len:
                return make_match(index, index + 1, idx_in_subseq)
            else:
                new_candidates.append(List(index, idx_in_subseq + 1, idx_in_subseq))

        # print(candidates, " new candidates ", new_candidates)
        for cand in candidates:
            # if this sequence char is the candidate's next expected char
            if subsequence[cand[1]] == char:
                # if reached the end of the subsequence, return a match
                if cand[1] + 1 == subseq_len:
                    return make_match(cand[0], index + 1, cand[2])
                # otherwise, update the candidate's subseq_index and keep it
                else:
                    new_candidates.append(List(cand[0], cand[1] + 1, cand[2]))

            # if this sequence char is *not* the candidate's next expected char
            else:
                # we can try skipping a sequence or sub-sequence char (or both),
                # unless this candidate has already skipped the maximum allowed
                # number of characters
                if cand[2] == max_l_dist:
                    continue

                # add a candidate skipping a sequence char
                new_candidates.append(List(cand[0], cand[1], cand[2] + 1))

                if index + 1 < len(sequence) and cand[1] + 1 < subseq_len:
                    # add a candidate skipping both a sequence char and a
                    # subsequence char
                    new_candidates.append(List(cand[0], cand[1] + 1, cand[2] + 1))

                # try skipping subsequence chars
                for n_skipped in range(1, max_l_dist - cand[2] + 1):
                    # if skipping n_skipped sub-sequence chars reaches the end
                    # of the sub-sequence, yield a match
                    if cand[1] + n_skipped == subseq_len:
                        return make_match(cand[0], index + 1, cand[2] + n_skipped)
                        break
                    # otherwise, if skipping n_skipped sub-sequence chars
                    # reaches a sub-sequence char identical to this sequence
                    # char, add a candidate skipping n_skipped sub-sequence
                    # chars
                    elif subsequence[cand[1] + n_skipped] == char:
                        # if this is the last char of the sub-sequence, yield
                        # a match
                        if cand[1] + n_skipped + 1 == subseq_len:
                            return make_match(cand[0], index + 1,
                                             cand[2] + n_skipped)
                        # otherwise add a candidate skipping n_skipped
                        # subsequence chars
                        else:
                            new_candidates.append(List(cand[0], cand[1] + 1 + n_skipped, cand[2] + n_skipped))
                        break
                # note: if the above loop ends without a break, that means that
                # no candidate could be added / yielded by skipping sub-sequence
                # chars

        candidates = new_candidates

    for cand in candidates:
        dist = cand[2] + subseq_len - cand[1]
        if dist <= max_l_dist:
            return make_match(cand[0], len(sequence), dist)

Solution

  • The error message is very precise and it's pointing to the specific problem. Numba typed.List uses an homogeneous data type so it needs to know the type.

    You can create a typed list by initializing it:

    list_of_ints = nb.typed.List([1,2,3])
    

    Or create an empty one using the empty_list() factory to declare its type:

    empty_list_of_floats = nb.typed.List.empty_list(nb.f8)
    

    Or create an empty one and immediately appending an element:

    another_list_of_ints = nb.typed.List()
    another_list_of_ints.append(1)
    

    Or any combination:

    list_of_lists_of_floats = nb.typed.List()
    list_of_lists_of_floats.append(nb.typed.List.empty_list(nb.f8))
    list_of_lists_of_floats[0].append(1)