Search code examples
Tags: python, numpy, convolution, array-broadcasting

Propagating True entries along an axis in an array


I have to perform the operation below many times. By using NumPy functions instead of loops I usually get very good performance, but I have not been able to replicate this for higher-dimensional arrays. Any suggestion or alternative would be most welcome.

I have a boolean array and I would like to propagate each True value to the next 2 positions. For example:

Given this one-dimensional array `A`:

import numpy as np

# How many positions after each True value should also become True
propagation = 2

# Input mask: True at indices 1, 5 and 11
A = np.array(
    [False, True, False, False, False, True,
     False, False, False, False, False, True, False]
)

I can create an all-False array, find the indices of the True values with `np.argwhere`, add the propagation offsets to each of them, flatten the result, and set those positions to True:

B = np.zeros(A.shape, dtype=bool)

# Each True index i expands to i, i+1, ..., i+propagation: argwhere yields a
# (k, 1) column of positions and the offset row broadcasts across it.
idcs_true = (np.argwhere(A) + np.arange(propagation + 1)).ravel()
# Indices near the end of A can overshoot the last element; drop them.
idcs_true = idcs_true[idcs_true < A.size]
B[idcs_true] = True

# Show the input next to the result
print(f'Original array     A = {A}')
print(f'New array (2 true) B = {B}')

which gives:

Original array     A = [False  True False False False  True False False False False False  True
 False]
New array (2 true) B = [False  True  True  True False  True  True  True False False False  True
  True]

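However, this approach breaks down for higher-dimensional arrays. In the 2-D example below, the True values should propagate along each row, but `np.argwhere` now returns one column per dimension, so adding `np.arange(propagation + 1)` no longer produces the intended indices: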

AA = np.array([[False, True, False, False, False, True, False, False, False, False, False, True, False],
               [False, True, False, False, False, True, False, False, False, False, False, True, False]])

Thanks for any advice.


Solution

  • I'll just leave a Numba version here so you can compare its speed against the proposed solution:

    import numba
    import numpy as np
    
    
    @numba.njit(parallel=True)
    def propagate_true_numba(arr, n=2):
        """Mark every True entry of 2-D ``arr`` plus the next ``n`` entries in its row.

        Rows are handled independently (distributed with ``numba.prange``), so a
        run never spills from the end of one row into the next. Returns a
        ``uint8`` array of 0/1 values with the same shape as ``arr``.

        NOTE(review): with a negative ``n`` the countdown below never reaches
        zero, so every entry after the first True in a row gets set; callers
        should pass ``n >= 0``.
        """
        out = np.zeros_like(arr, dtype="uint8")
        n_rows, n_cols = arr.shape

        for row in numba.prange(n_rows):
            # Positions still to be filled after the most recent True in this row.
            remaining = 0
            for col in range(n_cols):
                if arr[row, col] == 1:
                    # A new True restarts the countdown.
                    remaining = n
                elif remaining:
                    remaining -= 1
                else:
                    continue
                out[row, col] = 1

        return out
    

    Benchmark:

    import numba
    import numpy as np
    import perfplot
    
    
    @numba.njit(parallel=True)
    def propagate_true_numba(arr, n=2):
        """Mark every True entry of 2-D ``arr`` plus the next ``n`` entries in its row.

        Rows are handled independently (distributed with ``numba.prange``), so a
        run never spills from the end of one row into the next. Returns a
        ``uint8`` array of 0/1 values with the same shape as ``arr``.

        NOTE(review): with a negative ``n`` the countdown below never reaches
        zero, so every entry after the first True in a row gets set; callers
        should pass ``n >= 0``.
        """
        out = np.zeros_like(arr, dtype="uint8")
        n_rows, n_cols = arr.shape

        for row in numba.prange(n_rows):
            # Positions still to be filled after the most recent True in this row.
            remaining = 0
            for col in range(n_cols):
                if arr[row, col] == 1:
                    # A new True restarts the countdown.
                    remaining = n
                elif remaining:
                    remaining -= 1
                else:
                    continue
                out[row, col] = 1

        return out
    
    
    def _prop_func(A, propagation):
        B = np.zeros(A.shape, dtype=bool)
    
        # Compute the indices of the True values and make the two next to it True as well
        idcs_true = np.argwhere(A) + np.arange(propagation + 1)
        idcs_true = idcs_true.flatten()
        idcs_true = idcs_true[
            idcs_true < A.size
        ]  # in case the error propagation gives a great
        B[idcs_true] = True
        return B
    
    
    def propagate_true_numpy(arr, n=2, axis=1):
        """Propagate every True entry of ``arr`` to the next ``n`` positions along ``axis``.

        Fully vectorized: instead of looping over 1-D slices in Python (as
        ``np.apply_along_axis`` does), the mask is OR-ed with ``n`` shifted
        copies of itself, so each pass is a single NumPy operation over the
        whole array. Propagation never crosses from one slice into the next,
        and positions past the end of the axis are simply dropped.

        Parameters
        ----------
        arr : array_like
            Input mask of any rank; nonzero values are treated as True.
        n : int, default 2
            Number of positions after each True that also become True.
            Must be non-negative.
        axis : int, default 1
            Axis along which to propagate (row-wise for a 2-D array).

        Returns
        -------
        numpy.ndarray
            Boolean array with the same shape as ``arr``.

        Raises
        ------
        ValueError
            If ``n`` is negative.
        numpy.exceptions.AxisError
            If ``axis`` is out of range for ``arr``.
        """
        if n < 0:
            raise ValueError(f"n must be non-negative, got {n}")

        # Work along the last axis so the shifts can be written with `...`.
        mask = np.moveaxis(np.asarray(arr, dtype=bool), axis, -1)
        out = mask.copy()  # never mutate the caller's array
        length = mask.shape[-1]

        # A shift of `length` or more touches nothing, so cap the shift count;
        # this also keeps `mask[..., :-shift]` from taking a wrong-sized slice.
        for shift in range(1, min(n, length - 1) + 1):
            out[..., shift:] |= mask[..., :-shift]

        return np.moveaxis(out, -1, axis)
    
    
    # Two identical rows taken from the question's example.
    AA = np.array(
        2 * [[False, True, False, False, False, True, False, False, False, False, False, True, False]]
    )

    # Sanity check: both implementations must agree before they are timed.
    x = propagate_true_numba(AA, 2)
    y = propagate_true_numpy(AA, 2)
    assert np.allclose(x, y)

    # Seed the global RNG so the random benchmark inputs are reproducible.
    np.random.seed(0)

    perfplot.show(
        setup=lambda n: np.random.randint(0, 2, size=(n, n), dtype="uint8"),
        kernels=[
            lambda arr: propagate_true_numpy(arr, 2),
            lambda arr: propagate_true_numba(arr, 2),
        ],
        labels=["numpy", "numba"],
        n_range=[10, 25, 50, 100, 250, 500, 1000, 2500, 5000],
        equality_check=np.allclose,
        xlabel="N * N",
        logx=True,
        logy=True,
    )
    

    On my AMD 5700X this produces the following graph:

    [Image: perfplot timing comparison of the numpy and numba versions]