I am writing a function in Python using numba to label objects in a 2D or 3D array: all orthogonally connected cells with the same value in the input array are given a unique label from 1 to N in the output array, where N is the number of orthogonally connected groups. It is very similar to functions such as `scipy.ndimage.label` and its counterparts in libraries such as scikit-image, but those functions label all orthogonally connected non-zero groups of cells, so they would merge connected groups with different values, which I don't want. For example, given this input:
```
[0 0 7 7 0 0
 0 0 7 0 0 0
 0 0 0 0 0 7
 0 6 6 0 0 7
 0 0 4 4 0 0]
```
The scipy function would return:

```
[0 0 1 1 0 0
 0 0 1 0 0 0
 0 0 0 0 0 3
 0 2 2 0 0 3
 0 0 2 2 0 0]
```
Notice that the 6s and 4s were merged into label 2. I want them to be labeled as separate groups, e.g.:
```
[0 0 1 1 0 0
 0 0 1 0 0 0
 0 0 0 0 0 4
 0 2 2 0 0 4
 0 0 3 3 0 0]
```
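For reference, the merging behavior can be reproduced directly with `scipy.ndimage.label`, which treats every non-zero cell as foreground (the assertion at the bottom just demonstrates that a 6-cell and a 4-cell end up with the same label):

```python
import numpy as np
from scipy import ndimage

image = np.array([
    [0, 0, 7, 7, 0, 0],
    [0, 0, 7, 0, 0, 0],
    [0, 0, 0, 0, 0, 7],
    [0, 6, 6, 0, 0, 7],
    [0, 0, 4, 4, 0, 0],
])

# Default structuring element = orthogonal (cross-shaped) connectivity.
labeled, n = ndimage.label(image)
print(n)  # 3 groups: the 6s and 4s are merged because both are non-zero
assert labeled[3, 1] == labeled[4, 3]  # a 6-cell shares its label with a 4-cell
```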
I asked this about a year ago and have been using the solution in the accepted answer; however, I am now working on optimizing the runtime of my code and am revisiting this problem.
For the data size I generally work with, the linked solution takes about 1m30s to run. I wrote the following recursive algorithm, which takes about 30s running as regular Python and 1-2s with numba's JIT (side note: I hate that `adjacent` function; any tips to make it less messy while still numba-compatible would be appreciated):
```python
@numba.njit
def adjacent(idx, shape):
    coords = []
    if len(shape) > 2:
        if idx[0] < shape[0] - 1:
            coords.append((idx[0] + 1, idx[1], idx[2]))
        if idx[0] > 0:
            coords.append((idx[0] - 1, idx[1], idx[2]))
        if idx[1] < shape[1] - 1:
            coords.append((idx[0], idx[1] + 1, idx[2]))
        if idx[1] > 0:
            coords.append((idx[0], idx[1] - 1, idx[2]))
        if idx[2] < shape[2] - 1:
            coords.append((idx[0], idx[1], idx[2] + 1))
        if idx[2] > 0:
            coords.append((idx[0], idx[1], idx[2] - 1))
    else:
        if idx[0] < shape[0] - 1:
            coords.append((idx[0] + 1, idx[1]))
        if idx[0] > 0:
            coords.append((idx[0] - 1, idx[1]))
        if idx[1] < shape[1] - 1:
            coords.append((idx[0], idx[1] + 1))
        if idx[1] > 0:
            coords.append((idx[0], idx[1] - 1))
    return coords

@numba.njit
def apply_label(labels, decoded_image, current_label, idx):
    labels[idx] = current_label
    for aidx in adjacent(idx, labels.shape):
        if decoded_image[aidx] == decoded_image[idx] and labels[aidx] == 0:
            apply_label(labels, decoded_image, current_label, aidx)

@numba.njit
def label_image(decoded_image):
    labels = np.zeros_like(decoded_image, dtype=np.uint32)
    current_label = 0
    for idx in zip(*np.where(decoded_image >= 0)):
        if labels[idx] == 0:
            current_label += 1
            apply_label(labels, decoded_image, current_label, idx)
    return labels, current_label
```
This worked for some data, but crashed on other data, and I found the issue is that when there are very large objects to label, the recursion limit is reached. I tried to rewrite `label_image` to not use recursion, but it now takes ~10s with numba. That's still a huge improvement from where I started, but it seems like it should be possible to get the same performance as the recursive version. Here is my iterative version:
```python
@numba.njit
def label_image(decoded_image):
    labels = np.zeros_like(decoded_image, dtype=np.uint32)
    current_label = 0
    for idx in zip(*np.where(decoded_image >= 0)):
        if labels[idx] == 0:
            current_label += 1
            idxs = [idx]
            while idxs:
                cidx = idxs.pop()
                if labels[cidx] == 0:
                    labels[cidx] = current_label
                    for aidx in adjacent(cidx, labels.shape):
                        if labels[aidx] == 0 and decoded_image[aidx] == decoded_image[idx]:
                            idxs.append(aidx)
    return labels, current_label
```
Is there a way I can improve this?
Can this recursive function be turned into an iterative function with similar performance?
Turning this into an iterative function is straightforward, considering it's just a simple depth-first search (you could also use a breadth-first search with a queue instead of a stack here; both work). Simply use a stack to keep track of the nodes to visit. Here's a general solution that works with any number of dimensions:
```python
def label_image(decoded_image):
    shape = decoded_image.shape
    labels = np.zeros_like(decoded_image, dtype=np.uint32)
    current_label = 0
    for idx in zip(*np.where(decoded_image > 0)):
        if labels[idx] == 0:
            current_label += 1
            stack = [idx]
            while stack:
                top = stack.pop()
                labels[top] = current_label
                for i in range(0, len(shape)):
                    if top[i] > 0:
                        neighbor = list(top)
                        neighbor[i] -= 1
                        neighbor = tuple(neighbor)
                        if decoded_image[neighbor] == decoded_image[idx] and labels[neighbor] == 0:
                            stack.append(neighbor)
                    if top[i] < shape[i] - 1:
                        neighbor = list(top)
                        neighbor[i] += 1
                        neighbor = tuple(neighbor)
                        if decoded_image[neighbor] == decoded_image[idx] and labels[neighbor] == 0:
                            stack.append(neighbor)
    return labels
```
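As an aside, the breadth-first variant only differs in popping from the front of a queue instead of the back of a stack. A minimal plain-Python sketch of the traversal for a single group (the `flood_bfs` helper and the tiny example array are only for illustration; as far as I know `collections.deque` is not supported in numba's nopython mode, so this is un-jitted):

```python
from collections import deque

import numpy as np

def flood_bfs(image, labels, start, label):
    """Label one orthogonally connected same-value group, breadth-first."""
    value = image[start]
    queue = deque([start])
    while queue:
        x, y = queue.popleft()  # popleft() instead of pop() is the only change
        if image[x, y] != value or labels[x, y] != 0:
            continue  # already visited or not part of this group
        labels[x, y] = label
        if x > 0: queue.append((x - 1, y))
        if x + 1 < image.shape[0]: queue.append((x + 1, y))
        if y > 0: queue.append((x, y - 1))
        if y + 1 < image.shape[1]: queue.append((x, y + 1))

image = np.array([[7, 7, 0],
                  [7, 0, 7]])
labels = np.zeros(image.shape, dtype=np.uint32)
flood_bfs(image, labels, (0, 0), 1)
print(labels)  # only the 7s reachable from (0, 0) get label 1
```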
Adding or subtracting one from the i-th component of the tuple is awkward, though (I'm going through a temporary list here), and numba doesn't accept it (type error). One simple solution would be to explicitly write versions for 2D and 3D, which will likely also greatly help performance:
```python
@numba.njit
def label_image_2d(decoded_image):
    w, h = decoded_image.shape
    labels = np.zeros_like(decoded_image, dtype=np.uint32)
    current_label = 0
    for idx in zip(*np.where(decoded_image > 0)):
        if labels[idx] == 0:
            current_label += 1
            stack = [idx]
            while stack:
                x, y = stack.pop()
                if decoded_image[x, y] != decoded_image[idx] or labels[x, y] != 0:
                    continue  # already visited or not part of this group
                labels[x, y] = current_label
                if x > 0: stack.append((x-1, y))
                if x+1 < w: stack.append((x+1, y))
                if y > 0: stack.append((x, y-1))
                if y+1 < h: stack.append((x, y+1))
    return labels

@numba.njit
def label_image_3d(decoded_image):
    w, h, l = decoded_image.shape
    labels = np.zeros_like(decoded_image, dtype=np.uint32)
    current_label = 0
    for idx in zip(*np.where(decoded_image > 0)):
        if labels[idx] == 0:
            current_label += 1
            stack = [idx]
            while stack:
                x, y, z = stack.pop()
                if decoded_image[x, y, z] != decoded_image[idx] or labels[x, y, z] != 0:
                    continue  # already visited or not part of this group
                labels[x, y, z] = current_label
                if x > 0: stack.append((x-1, y, z))
                if x+1 < w: stack.append((x+1, y, z))
                if y > 0: stack.append((x, y-1, z))
                if y+1 < h: stack.append((x, y+1, z))
                if z > 0: stack.append((x, y, z-1))
                if z+1 < l: stack.append((x, y, z+1))
    return labels

def label_image(decoded_image):
    dim = len(decoded_image.shape)
    if dim == 2:
        return label_image_2d(decoded_image)
    assert dim == 3
    return label_image_3d(decoded_image)
```
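The jitted functions are also valid plain Python, so removing the `@numba.njit` decorator gives a convenient way to sanity-check the labeling against the example from the question (slow, but handy for testing; the small check at the bottom is only for illustration):

```python
import numpy as np

def label_image_2d(decoded_image):  # same body as above, decorator removed
    w, h = decoded_image.shape
    labels = np.zeros_like(decoded_image, dtype=np.uint32)
    current_label = 0
    for idx in zip(*np.where(decoded_image > 0)):
        if labels[idx] == 0:
            current_label += 1
            stack = [idx]
            while stack:
                x, y = stack.pop()
                if decoded_image[x, y] != decoded_image[idx] or labels[x, y] != 0:
                    continue  # already visited or not part of this group
                labels[x, y] = current_label
                if x > 0: stack.append((x-1, y))
                if x+1 < w: stack.append((x+1, y))
                if y > 0: stack.append((x, y-1))
                if y+1 < h: stack.append((x, y+1))
    return labels

image = np.array([
    [0, 0, 7, 7, 0, 0],
    [0, 0, 7, 0, 0, 0],
    [0, 0, 0, 0, 0, 7],
    [0, 6, 6, 0, 0, 7],
    [0, 0, 4, 4, 0, 0],
])
labels = label_image_2d(image)
assert labels.max() == 4             # four separate groups, unlike scipy's three
assert labels[3, 1] != labels[4, 2]  # the 6s and 4s keep distinct labels
```

Note that the exact label numbers depend on raster-scan order of the seeds, but each same-value group gets its own label, which is what matters here.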
Note also that the iterative solution doesn't suffer from stack limits: `np.full((100,100,100), 1)` works just fine in the iterative solution, but fails in the recursive solution (segfaults if using numba).
Doing a very rudimentary benchmark of

```python
for i in range(1, 10000):
    label_image(np.full((20,20,20), i))
```

(many iterations to minimize the impact of JIT; one could also do a few warmup runs and then start measuring, or similar),
the iterative solution seems to be several times faster (about 5x on my machine; see below). You could probably optimize the recursive solution and get it to a comparable speed, e.g. by avoiding the temporary `coords` list or by changing the `decoded_image >= 0` in the `np.where` to `> 0`.
I don't know how well numba can optimize the zipped `np.where`. For further optimization, you could consider (and benchmark) using explicit nested `for x in range(0, w): for y in range(0, h):` loops there.
To remain competitive with the merge strategy proposed by Nick, I've optimized this further, picking some low-hanging fruit:

- Changed the `zip(*np.where(...))` to explicit nested loops with `continue`, rather than `np.where`.
- Cached `decoded_image[idx]` in a local variable (ideally shouldn't matter, but doesn't hurt).
- Reused the `stack` list across groups, so it isn't reallocated for each one (it can never hold more than `w*h` or `w*h*l` entries, respectively).

```python
@numba.njit
def label_image_2d(decoded_image):
    w, h = decoded_image.shape
    labels = np.zeros_like(decoded_image, dtype=np.uint32)
    current_label = 0
    stack = []
    for sx in range(0, w):
        for sy in range(0, h):
            start = (sx, sy)
            image_label = decoded_image[start]
            if image_label <= 0 or labels[start] != 0:
                continue
            current_label += 1
            stack.append(start)
            while stack:
                x, y = stack.pop()
                if decoded_image[x, y] != image_label or labels[x, y] != 0:
                    continue  # already visited or not part of this group
                labels[x, y] = current_label
                if x > 0: stack.append((x-1, y))
                if x+1 < w: stack.append((x+1, y))
                if y > 0: stack.append((x, y-1))
                if y+1 < h: stack.append((x, y+1))
    return labels

@numba.njit
def label_image_3d(decoded_image):
    w, h, l = decoded_image.shape
    labels = np.zeros_like(decoded_image, dtype=np.uint32)
    current_label = 0
    stack = []
    for sx in range(0, w):
        for sy in range(0, h):
            for sz in range(0, l):
                start = (sx, sy, sz)
                image_label = decoded_image[start]
                if image_label <= 0 or labels[start] != 0:
                    continue
                current_label += 1
                stack.append(start)
                while stack:
                    x, y, z = stack.pop()
                    if decoded_image[x, y, z] != image_label or labels[x, y, z] != 0:
                        continue  # already visited or not part of this group
                    labels[x, y, z] = current_label
                    if x > 0: stack.append((x-1, y, z))
                    if x+1 < w: stack.append((x+1, y, z))
                    if y > 0: stack.append((x, y-1, z))
                    if y+1 < h: stack.append((x, y+1, z))
                    if z > 0: stack.append((x, y, z-1))
                    if z+1 < l: stack.append((x, y, z+1))
    return labels
```
I then cobbled together a benchmark to compare the four approaches (original recursive, old iterative, new iterative, merge-based), putting them in four different modules:
```python
import numpy as np
import timeit

import rec
import iter_old
import iter_new
import merge

shape = (100, 100, 100)
n = 20

for module in [rec, iter_old, iter_new, merge]:
    print(module)
    label_image = module.label_image

    # Trigger compilation of 2d & 3d functions
    label_image(np.zeros((1, 1)))
    label_image(np.zeros((1, 1, 1)))

    i = 0
    def test_full():
        global i
        i += 1
        label_image(np.full(shape, i))
    print("single group:", timeit.timeit(test_full, number=n))
    print("random (few groups):", timeit.timeit(
        lambda: label_image(np.random.randint(low = 1, high = 10, size = shape)),
        number=n))
    print("random (many groups):", timeit.timeit(
        lambda: label_image(np.random.randint(low = 1, high = 400, size = shape)),
        number=n))
    print("only groups:", timeit.timeit(
        lambda: label_image(np.arange(np.prod(shape)).reshape(shape)),
        number=n))
```
This outputs something like
```
<module 'rec' from '...'>
single group: 32.39212468900041
random (few groups): 14.648884047001047
random (many groups): 13.304533919001187
only groups: 13.513677138000276
<module 'iter_old' from '...'>
single group: 10.287227957000141
random (few groups): 17.37535468200076
random (many groups): 14.506630064999626
only groups: 13.132202609998785
<module 'iter_new' from '...'>
single group: 7.388022166000155
random (few groups): 11.585243002000425
random (many groups): 9.560101995000878
only groups: 8.693653742000606
<module 'merge' from '...'>
single group: 14.657021331999204
random (few groups): 14.146574055999736
random (many groups): 13.412314713001251
only groups: 12.642367746000673
```
It seems to me that the improved iterative approach may be better. Note that the original rudimentary benchmark seems to be the worst case for the recursive variant. In general the difference isn't as large.
The tested array is pretty small (20³). If I test with a larger array (100³) and a smaller n (20), I get roughly the following results (`rec` is omitted because, due to stack limits, it would segfault):
```
<module 'iter_old' from '...'>
single group: 3.5357716739999887
random (few groups): 4.931695729999774
random (many groups): 3.4671142009992764
only groups: 3.3023930709987326
<module 'iter_new' from '...'>
single group: 2.45903080700009
random (few groups): 2.907660342001691
random (many groups): 2.309699692999857
only groups: 2.052835552000033
<module 'merge' from '...'>
single group: 3.7620838259990705
random (few groups): 3.3524249689999124
random (many groups): 3.126650959999097
only groups: 2.9456547739991947
```
The iterative approach still seems to be more efficient.