Tags: python, numpy, recursion, numba

Can this recursive function be turned into an iterative function with similar performance?


I am writing a function in Python using numba to label objects in a 2D or 3D array: all orthogonally connected cells with the same value in the input array are given a unique label from 1 to N in the output array, where N is the number of orthogonally connected groups. It is very similar to scipy.ndimage.label and comparable functions in libraries such as scikit-image, but those functions label all orthogonally connected non-zero groups of cells, so they would merge connected groups with different values, which I don't want. For example, given this input:

[0 0 7 7 0 0
 0 0 7 0 0 0
 0 0 0 0 0 7
 0 6 6 0 0 7
 0 0 4 4 0 0]

The scipy function would return

[0 0 1 1 0 0
 0 0 1 0 0 0
 0 0 0 0 0 3
 0 2 2 0 0 3
 0 0 2 2 0 0]

Notice that the 6s and 4s were merged into the label 2. I want them to be labeled as separate groups, e.g.:

[0 0 1 1 0 0
 0 0 1 0 0 0
 0 0 0 0 0 4
 0 2 2 0 0 4
 0 0 3 3 0 0]
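
This merging behaviour can be reproduced directly with scipy.ndimage.label (orthogonal connectivity is its default):

```python
import numpy as np
from scipy import ndimage

a = np.array([[0, 0, 7, 7, 0, 0],
              [0, 0, 7, 0, 0, 0],
              [0, 0, 0, 0, 0, 7],
              [0, 6, 6, 0, 0, 7],
              [0, 0, 4, 4, 0, 0]])

# ndimage.label ignores the cell values and only distinguishes zero from
# non-zero, so the touching 6s and 4s come back as a single feature.
labeled, num = ndimage.label(a)
# num == 3 here; the value-aware labeling I want would give 4 groups.
```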

I asked this about a year ago and have been using the solution in the accepted answer; however, I am now working on optimizing the runtime of my code and am revisiting this problem.

For the data size I generally work with, the linked solution takes about 1m30s to run. I wrote the following recursive algorithm, which takes about 30s as regular Python and runs in 1-2s with numba's JIT (side note: I hate that adjacent function; any tips to make it less messy while keeping it numba-compatible would be appreciated):

@numba.njit
def adjacent(idx, shape):
    coords = []
    if len(shape) > 2:
        if idx[0] < shape[0] - 1:
            coords.append((idx[0] + 1, idx[1], idx[2]))
        if idx[0] > 0:
            coords.append((idx[0] - 1, idx[1], idx[2]))
        if idx[1] < shape[1] - 1:
            coords.append((idx[0], idx[1] + 1, idx[2]))
        if idx[1] > 0:
            coords.append((idx[0], idx[1] - 1, idx[2]))
        if idx[2] < shape[2] - 1:
            coords.append((idx[0], idx[1], idx[2] + 1))
        if idx[2] > 0:
            coords.append((idx[0], idx[1], idx[2] - 1))
    else:
        if idx[0] < shape[0] - 1:
            coords.append((idx[0] + 1, idx[1]))
        if idx[0] > 0:
            coords.append((idx[0] - 1, idx[1]))
        if idx[1] < shape[1] - 1:
            coords.append((idx[0], idx[1] + 1))
        if idx[1] > 0:
            coords.append((idx[0], idx[1] - 1))
    return coords


@numba.njit
def apply_label(labels, decoded_image, current_label, idx):
    labels[idx] = current_label
    for aidx in adjacent(idx, labels.shape):
        if decoded_image[aidx] == decoded_image[idx] and labels[aidx] == 0:
            apply_label(labels, decoded_image, current_label, aidx)


@numba.njit
def label_image(decoded_image):
    labels = np.zeros_like(decoded_image, dtype=np.uint32)
    current_label = 0
    for idx in zip(*np.where(decoded_image >= 0)):
        if labels[idx] == 0:
            current_label += 1
            apply_label(labels, decoded_image, current_label, idx)
    return labels, current_label

This worked for some data but crashed on other data; the issue turned out to be that when there are very large objects to label, the recursion limit is reached. I tried to rewrite label_image without recursion, but it now takes ~10s with numba. That is still a huge improvement from where I started, but it seems like it should be possible to get the same performance as the recursive version. Here is my iterative version:

@numba.njit
def label_image(decoded_image):
    labels = np.zeros_like(decoded_image, dtype=np.uint32)
    current_label = 0
    for idx in zip(*np.where(decoded_image >= 0)):
        if labels[idx] == 0:
            current_label += 1
            idxs = [idx]
            while idxs:
                cidx = idxs.pop()
                if labels[cidx] == 0:
                    labels[cidx] = current_label
                    for aidx in adjacent(cidx, labels.shape):
                        if labels[aidx] == 0 and decoded_image[aidx] == decoded_image[idx]:
                            idxs.append(aidx)
    return labels, current_label

Is there a way I can improve this?


Solution

  • Can this recursive function be turned into an iterative function with similar performance?

    Turning this into an iterative function is straightforward, considering it's just a simple depth-first search (you could also use a breadth-first search using a queue instead of a stack here, both work). Simply use a stack to keep track of the nodes to visit. Here's a general solution that works with any number of dimensions:

    def label_image(decoded_image):
        shape = decoded_image.shape
        labels = np.zeros_like(decoded_image, dtype=np.uint32)
        current_label = 0
        for idx in zip(*np.where(decoded_image > 0)):
            if labels[idx] == 0:
                current_label += 1
                stack = [idx]
                while stack:
                    top = stack.pop()
                    labels[top] = current_label
                    for i in range(0, len(shape)):
                        if top[i] > 0:
                            neighbor = list(top)
                            neighbor[i] -= 1
                            neighbor = tuple(neighbor)
                            if decoded_image[neighbor] == decoded_image[idx] and labels[neighbor] == 0:
                                stack.append(neighbor)
                        if top[i] < shape[i] - 1:
                            neighbor = list(top)
                            neighbor[i] += 1
                            neighbor = tuple(neighbor)
                            if decoded_image[neighbor] == decoded_image[idx] and labels[neighbor] == 0:
                                stack.append(neighbor)
        return labels
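
As an aside, the BFS variant mentioned above differs only in which end of the container is popped. A toy sketch over an abstract neighbor function (not part of the labeling code) makes the point:

```python
from collections import deque

# Toy flood fill over an abstract neighbor function. popleft() gives BFS;
# using pop() instead would give DFS. The set of visited nodes is identical
# either way; only the traversal order changes.
def flood(start, neighbors):
    seen = {start}
    queue = deque([start])
    order = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for nb in neighbors(node):
            if nb not in seen:
                seen.add(nb)
                queue.append(nb)
    return order

# BFS on the path graph 0-1-2-3-4 starting from 0 visits in index order.
order = flood(0, lambda n: [v for v in (n - 1, n + 1) if 0 <= v < 5])
```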
    

    Adding or subtracting one from the i-th component of the tuple is awkward though (I'm going over a temporary list here) and numba doesn't accept it (type error). One simple solution would be to explicitly write versions for 2d and 3d, which will likely greatly help performance:

    @numba.njit
    def label_image_2d(decoded_image):
        w, h = decoded_image.shape
        labels = np.zeros_like(decoded_image, dtype=np.uint32)
        current_label = 0
        for idx in zip(*np.where(decoded_image > 0)):
            if labels[idx] == 0:
                current_label += 1
                stack = [idx]
                while stack:
                    x, y = stack.pop()
                    if decoded_image[x, y] != decoded_image[idx] or labels[x, y] != 0:
                        continue # already visited or not part of this group
                    labels[x, y] = current_label
                    if x > 0: stack.append((x-1, y))
                    if x+1 < w: stack.append((x+1, y))
                    if y > 0: stack.append((x, y-1))
                    if y+1 < h: stack.append((x, y+1))
        return labels
    
    @numba.njit
    def label_image_3d(decoded_image):
        w, h, l = decoded_image.shape
        labels = np.zeros_like(decoded_image, dtype=np.uint32)
        current_label = 0
        for idx in zip(*np.where(decoded_image > 0)):
            if labels[idx] == 0:
                current_label += 1
                stack = [idx]
                while stack:
                    x, y, z = stack.pop()
                    if decoded_image[x, y, z] != decoded_image[idx] or labels[x, y, z] != 0:
                        continue # already visited or not part of this group
                    labels[x, y, z] = current_label
                    if x > 0: stack.append((x-1, y, z))
                    if x+1 < w: stack.append((x+1, y, z))
                    if y > 0: stack.append((x, y-1, z))
                    if y+1 < h: stack.append((x, y+1, z))
                    if z > 0: stack.append((x, y, z-1))
                    if z+1 < l: stack.append((x, y, z+1))
        return labels
    
    def label_image(decoded_image):
        dim = len(decoded_image.shape)
        if dim == 2:
            return label_image_2d(decoded_image)
        assert dim == 3
        return label_image_3d(decoded_image)
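
For what it's worth, in plain Python the temporary list in the dimension-generic version can be avoided with tuple slicing; whether numba's type inference accepts this form is untested, so treat it as a sketch:

```python
def offset(idx, i, delta):
    # Return idx with its i-th component shifted by delta, using tuple
    # concatenation instead of round-tripping through a temporary list.
    return idx[:i] + (idx[i] + delta,) + idx[i + 1:]

# e.g. offset((1, 2, 3), 1, -1) yields (1, 1, 3)
```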
    

    Note also that the iterative solution doesn't suffer from stack limits: np.full((100,100,100), 1) works just fine in the iterative solution, but fails in the recursive solution (segfaults if using numba).
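
For the pure-Python recursive version, the interpreter's frame limit can be inspected and raised with sys.setrecursionlimit, but that is only a stopgap (deep recursion can still exhaust the C stack) and it has no effect on the numba-compiled version:

```python
import sys

# CPython's default recursion limit is around 1000 frames; a single large
# connected component can require far deeper recursion in a DFS flood fill.
default_limit = sys.getrecursionlimit()

# Raising the limit postpones the failure rather than fixing it, and does
# not apply to numba-compiled code at all.
sys.setrecursionlimit(max(default_limit, 10_000))
```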

    Doing a very rudimentary benchmark of

    for i in range(1, 10000):
        label_image(np.full((20,20,20), i))
    

    (many iterations to minimize the impact of JIT, could also do a few warmup runs, then start measuring time or similar)
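
Doing the warmup explicitly might look like the helper below (timeit_jit is a hypothetical name, not part of any library):

```python
import time

def timeit_jit(fn, warmup=3, runs=10):
    # Call fn a few times first so the one-time JIT compilation cost is
    # excluded, then report the mean wall-clock time of the measured runs.
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(runs):
        fn()
    return (time.perf_counter() - start) / runs

mean_seconds = timeit_jit(lambda: sum(range(1000)))
```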

The iterative solution seems to be several times faster (about 5x on my machine; see below). You could probably optimize the recursive solution to a comparable speed, e.g. by avoiding the temporary coords list or by changing the np.where condition to > 0.

    I don't know how well numba can optimize the zipped np.where. For further optimization, you could consider (and benchmark) using explicit nested for x in range(0, w): for y in range(0, h): loops there.


To remain competitive with the merge strategy proposed by Nick, I've optimized this further, picking some low-hanging fruit:

    • Convert the zip to explicit loops with continue rather than np.where.
    • Store decoded_image[idx] in a local variable (ideally shouldn't matter, but doesn't hurt).
• Reuse the stack. This prevents unnecessary (re)allocations and GC strain. It could further be considered to provide an initial capacity for the stack (of w*h or w*h*l respectively).

    @numba.njit
    def label_image_2d(decoded_image):
        w, h = decoded_image.shape
        labels = np.zeros_like(decoded_image, dtype=np.uint32)
        current_label = 0
        stack = []
        for sx in range(0, w):
            for sy in range(0, h):
                start = (sx, sy)
                image_label = decoded_image[start]
                if image_label <= 0 or labels[start] != 0:
                    continue
                current_label += 1
                stack.append(start)
                while stack:
                    x, y = stack.pop()
                    if decoded_image[x, y] != image_label or labels[x, y] != 0:
                        continue # already visited or not part of this group
                    labels[x, y] = current_label
                    if x > 0: stack.append((x-1, y))
                    if x+1 < w: stack.append((x+1, y))
                    if y > 0: stack.append((x, y-1))
                    if y+1 < h: stack.append((x, y+1))
        return labels
    
    @numba.njit
    def label_image_3d(decoded_image):
        w, h, l = decoded_image.shape
        labels = np.zeros_like(decoded_image, dtype=np.uint32)
        current_label = 0
        stack = []
        for sx in range(0, w):
            for sy in range(0, h):
                for sz in range(0, l):
                    start = (sx, sy, sz)
                    image_label = decoded_image[start]
                    if image_label <= 0 or labels[start] != 0:
                        continue
                    current_label += 1
                    stack.append(start)
                    while stack:
                        x, y, z = stack.pop()
                        if decoded_image[x, y, z] != image_label or labels[x, y, z] != 0:
                            continue # already visited or not part of this group
                        labels[x, y, z] = current_label
                        if x > 0: stack.append((x-1, y, z))
                        if x+1 < w: stack.append((x+1, y, z))
                        if y > 0: stack.append((x, y-1, z))
                        if y+1 < h: stack.append((x, y+1, z))
                        if z > 0: stack.append((x, y, z-1))
                        if z+1 < l: stack.append((x, y, z+1))
        return labels
    

    I then cobbled together a benchmark to compare the four approaches (original recursive, old iterative, new iterative, merge-based), putting them in four different modules:

    import numpy as np
    import timeit
    
    import rec
    import iter_old
    import iter_new
    import merge
    
    shape = (100, 100, 100)
    n = 20
    for module in [rec, iter_old, iter_new, merge]:
        print(module)
    
        label_image = module.label_image
        # Trigger compilation of 2d & 3d functions
        label_image(np.zeros((1, 1)))
        label_image(np.zeros((1, 1, 1)))
    
        i = 0
        def test_full():
            global i
            i += 1
            label_image(np.full(shape, i))
        print("single group:", timeit.timeit(test_full, number=n))
        print("random (few groups):", timeit.timeit(
            lambda: label_image(np.random.randint(low = 1, high = 10, size = shape)),
            number=n))
        print("random (many groups):", timeit.timeit(
            lambda: label_image(np.random.randint(low = 1, high = 400, size = shape)),
            number=n))
        print("only groups:", timeit.timeit(
            lambda: label_image(np.arange(np.prod(shape)).reshape(shape)),
            number=n))
    

    This outputs something like

    <module 'rec' from '...'>
    single group: 32.39212468900041
    random (few groups): 14.648884047001047
    random (many groups): 13.304533919001187
    only groups: 13.513677138000276
    <module 'iter_old' from '...'>
    single group: 10.287227957000141
    random (few groups): 17.37535468200076
    random (many groups): 14.506630064999626
    only groups: 13.132202609998785
    <module 'iter_new' from '...'>
    single group: 7.388022166000155
    random (few groups): 11.585243002000425
    random (many groups): 9.560101995000878
    only groups: 8.693653742000606
    <module 'merge' from '...'>
    single group: 14.657021331999204
    random (few groups): 14.146574055999736
    random (many groups): 13.412314713001251
    only groups: 12.642367746000673
    

    It seems to me that the improved iterative approach may be better. Note that the original rudimentary benchmark seems to be the worst case for the recursive variant. In general the difference isn't as large.

The tested array is pretty small (20³). If I test with a larger array (100³) and a smaller n (20), I get roughly the following results (rec is omitted because, due to stack limits, it would segfault):

    <module 'iter_old' from '...'>
    single group: 3.5357716739999887
    random (few groups): 4.931695729999774
    random (many groups): 3.4671142009992764
    only groups: 3.3023930709987326
    <module 'iter_new' from '...'>
    single group: 2.45903080700009
    random (few groups): 2.907660342001691
    random (many groups): 2.309699692999857
    only groups: 2.052835552000033
    <module 'merge' from '...'>
    single group: 3.7620838259990705
    random (few groups): 3.3524249689999124
    random (many groups): 3.126650959999097
    only groups: 2.9456547739991947
    

    The iterative approach still seems to be more efficient.
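
As a closing sanity check (my addition, not part of the answer), a plain-Python version of the iterative 2D flood fill can be run on the small example from the question. Label numbering depends on scan order, so the check asserts only the group structure:

```python
import numpy as np

def label_image_2d(decoded_image):
    # Plain-Python (no numba) version of the optimized iterative flood fill.
    w, h = decoded_image.shape
    labels = np.zeros_like(decoded_image, dtype=np.uint32)
    current_label = 0
    stack = []
    for sx in range(w):
        for sy in range(h):
            image_label = decoded_image[sx, sy]
            if image_label <= 0 or labels[sx, sy] != 0:
                continue
            current_label += 1
            stack.append((sx, sy))
            while stack:
                x, y = stack.pop()
                if decoded_image[x, y] != image_label or labels[x, y] != 0:
                    continue  # already visited or not part of this group
                labels[x, y] = current_label
                if x > 0: stack.append((x - 1, y))
                if x + 1 < w: stack.append((x + 1, y))
                if y > 0: stack.append((x, y - 1))
                if y + 1 < h: stack.append((x, y + 1))
    return labels, current_label

a = np.array([[0, 0, 7, 7, 0, 0],
              [0, 0, 7, 0, 0, 0],
              [0, 0, 0, 0, 0, 7],
              [0, 6, 6, 0, 0, 7],
              [0, 0, 4, 4, 0, 0]])

labels, n = label_image_2d(a)
# Four distinct groups: the touching 6s and 4s stay separate.
```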