python arrays numpy coordinates coordinate-transformation

Speed up numpy looking for best indices


I have a numpy array that maps x-y-coordinates to the appropriate z-coordinates. For this I use a 2D array that represents x and y as its axes and contains the corresponding z values:

import numpy as np
x_size = 2000
y_size = 2500
z_size = 400
rng = np.random.default_rng(123)
z_coordinates = np.linspace(0, z_size, y_size) + rng.laplace(0, 1, (x_size, y_size))

So each of the 2000*2500 x-y points is assigned a z-value (a float roughly between 0 and 400). Now, for each integer x and each integer z, I want to look up the y index whose z-value is closest to z, essentially creating a map of shape (x_size, z_size) that holds the best y-values.

The simplest approach is creating an empty array of target shape and iterating over each z value:

y_coordinates = np.empty((x_size, z_size), dtype=np.uint16)
for i in range(z_size):
    y_coordinates[:, i] = np.argmin(
        np.abs(z_coordinates - i),
        axis=1,
    )

However, this takes about 11 s on my machine, which unfortunately is way too slow.

Surely using a more vectorised approach would be faster, such as:

y_coordinates = np.argmin(
    np.abs(
        z_coordinates[..., np.newaxis] - np.arange(z_size)
    ),
    axis=1,
)

Surprisingly this runs about 60% slower than the version above (tested at 1/10th size, since at full size this uses excessive memory).
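The memory problem can be estimated directly from the shape of the broadcast intermediate. The following is a rough sketch, assuming float64 data and that the subtraction result and the np.abs result are both alive at the same moment:

```python
import numpy as np

x_size, y_size, z_size = 2000, 2500, 400  # sizes from the question

# z_coordinates[..., np.newaxis] - np.arange(z_size) has shape (x_size, y_size, z_size)
itemsize = np.dtype(np.float64).itemsize
temp_bytes = x_size * y_size * z_size * itemsize
temp_gib = temp_bytes / 2**30

print(f"{temp_gib:.1f} GiB per temporary")
print(f"{2 * temp_gib:.1f} GiB if two temporaries coexist")
```

At 1/10th of each dimension the same intermediate is 1000 times smaller, which is why the reduced-size test fits in memory.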

Also wrapping the code blocks in functions and decorating them with numba's @jit(nopython=True) doesn't help.

How can I speed up the calculation?


Solution

  • This answer provides an algorithm with optimal complexity: O(x_size * (y_size + z_size)). It is the fastest one proposed so far (by a large margin) and is implemented in Numba using multiple threads.
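To put the stated complexity in perspective, this back-of-the-envelope sketch compares the number of element visits of the naive loop from the question, x_size * y_size * z_size, with the x_size * (y_size + z_size) bound. Constant factors are ignored, so the ratio is only an order-of-magnitude indication, not a predicted speed-up.

```python
x_size, y_size, z_size = 2000, 2500, 400  # sizes from the question

# Naive loop: every integer z scans the whole 2D array
naive_visits = x_size * y_size * z_size
# Linear approach: one pass over each line plus one pass over the per-line array
linear_visits = x_size * (y_size + z_size)
ratio = naive_visits / linear_visits

print(f"naive:  {naive_visits:.2e}")
print(f"linear: {linear_visits:.2e}")
print(f"ratio:  {ratio:.0f}x")
```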


    Explanation of the approach

    The idea is that there is no need to iterate over all z values: we can iterate over z_coordinates line by line and, for each line, fill an array used to find the nearest value for each possible z. The best candidate for the value z is stored in arr[z].
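The core idea can be sketched in plain Python for a single line. This is a simplified illustration, not the Numba code of this answer: the name nearest_y_dense is hypothetical, and the sketch assumes the line is dense enough that every integer z has a value closer than 1, so no holes need to be filled afterwards.

```python
import numpy as np

def nearest_y_dense(row, z_size):
    """Single pass over one line; assumes every z ends up with a candidate."""
    best_val = np.full(z_size, np.inf)             # best value seen so far per z
    best_pos = np.full(z_size, -1, dtype=np.int64)  # its y index (-1 = none yet)
    for y, val in enumerate(row):
        if val < 0:                       # below the range: only z = 0 can benefit
            lo, hi = 0, 1
        elif val >= z_size:               # above the range: only the last z can benefit
            lo, hi = z_size - 1, z_size
        else:                             # update the truncated z and its two neighbours
            offset = int(val)
            lo, hi = max(offset - 1, 0), min(offset + 2, z_size)
        for z in range(lo, hi):
            if abs(val - z) < abs(best_val[z] - z):
                best_val[z] = val
                best_pos[z] = y
    return best_pos

# Deterministic, dense test line (perturbation bounded by 0.3)
z_size, y_size = 40, 250
row = np.linspace(0, z_size, y_size) + 0.3 * np.sin(np.arange(y_size))
brute_force = np.argmin(np.abs(row[:, None] - np.arange(z_size)), axis=0)
fast = nearest_y_dense(row, z_size)
print(np.array_equal(fast, brute_force))
```

Each value touches at most three entries of the per-line array, so the work per line is proportional to y_size rather than y_size * z_size.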

    In practice, there are tricky corner cases that make things more complicated. For example, due to rounding, I decided to also fill the neighbours of arr[z] (i.e. arr[z-1] and arr[z+1]) to keep the algorithm simple. Moreover, when a line of z_coordinates does not contain enough values to fill arr completely, we need to fill the holes in arr. In some more complicated cases (rounding issues combined with holes in arr), we need to correct the values in arr (or operate on more distant neighbours, which is not efficient). The number of steps in the correction function should always be a small constant, certainly <= 3 (it never reached 3 in my tests). Note that no corner case actually occurs on the specific input dataset provided.
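The hole-filling step can be illustrated on a tiny sparse line. For the line [0.2, 7.6] with z_size = 10, updating only the truncated z and its neighbours leaves candidates at z = 0, 1 and z = 6, 7, 8, and holes everywhere else. The sketch below fills each run of holes from the candidates that bound it; fill_holes and the MISSING sentinel are hypothetical names for this illustration, and the separate correction pass is not shown.

```python
import math

MISSING = -1  # hypothetical sentinel for this sketch (the answer uses 0xFFFF)

def fill_holes(best_val, best_pos):
    """Fill runs of MISSING entries from the candidates bounding each run."""
    z_size = len(best_pos)
    z = 0
    while z < z_size:
        if best_pos[z] != MISSING:
            z += 1
            continue
        end = z                            # find the end of the run of holes
        while end < z_size and best_pos[end] == MISSING:
            end += 1
        has_left, has_right = z > 0, end < z_size
        if not (has_left or has_right):
            raise ValueError("no candidate at all in this line")
        for k in range(z, end):
            # Prefer the left candidate only if it is strictly closer to k
            use_left = has_left and (
                not has_right
                or abs(best_val[z - 1] - k) < abs(best_val[end] - k)
            )
            src = z - 1 if use_left else end
            best_val[k], best_pos[k] = best_val[src], best_pos[src]
        z = end

# Candidates left by the neighbour update for the line [0.2, 7.6], z_size = 10
inf = math.inf
best_val = [0.2, 0.2, inf, inf, inf, inf, 7.6, 7.6, 7.6, inf]
best_pos = [0, 0, MISSING, MISSING, MISSING, MISSING, 1, 1, 1, MISSING]
fill_holes(best_val, best_pos)
print(best_pos)
```

For this line, z = 0 to 3 end up mapped to y = 0 and z = 4 to 9 to y = 1, which agrees with a brute-force nearest-value search.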

    Each line is computed in parallel using multiple threads. I assume the array is not too small (this avoids handling more corner cases and keeps the code simpler), which should not be an issue. I also assume there are no special values like NaN in z_coordinates.
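These assumptions can be checked before calling the compiled function. The helper below is a hypothetical pre-flight check, not part of the answer's code. It also converts the input to a C-contiguous float64 array, which the explicit signature float64[:,::1] requires, and rejects a y_size that cannot be represented next to the 0xFFFF "missing" marker in a uint16 array.

```python
import numpy as np

def check_input(z_coordinates, z_size):
    """Hypothetical pre-flight check for the assumptions stated above."""
    arr = np.ascontiguousarray(z_coordinates, dtype=np.float64)
    if arr.ndim != 2:
        raise ValueError("expected a 2D array")
    if arr.shape[1] < 2 or z_size < 2:
        raise ValueError("y_size and z_size must both be >= 2")
    if arr.shape[1] > 0xFFFF:
        # Valid y indices must stay below 0xFFFF, which marks a missing entry
        raise ValueError("y_size too large for uint16 indices")
    if np.isnan(arr).any():
        raise ValueError("NaN values are not supported")
    return arr
```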


    Resulting code

    Here is the final code:

    import numba as nb
    import numpy as np
    
    # Fill the missing values in the value-array if there are not enough values (e.g. pretty large z_size)
    # (untested)
    @nb.njit('(float64[::1], uint16[::1], int64)')
    def fill_missing_values(all_val, all_pos, z_size):
        i = 0
        while i < z_size:
            # If there is a missing value
            if all_pos[i] == 0xFFFF:
                j = i
                while j < z_size and all_pos[j] == 0xFFFF:
                    j += 1
                if i == 0:
                    # Fill the hole based on 1 value (lower bound): after the scan,
                    # j is the first index that holds a candidate
                    assert j < z_size and all_pos[j] != 0xFFFF
                    for i2 in range(i, j):
                        all_val[i2] = all_val[j]
                        all_pos[i2] = all_pos[j]
                elif j == z_size:
                    # Fill the hole based on 1 value (upper bound)
                    assert i-1 >= 0 and all_pos[i-1] != 0xFFFF and all_pos[i] == 0xFFFF
                    for i2 in range(i, j):
                        all_val[i2] = all_val[i-1]
                        all_pos[i2] = all_pos[i-1]
                else:
                    assert i-1 >= 0 and j < z_size and all_pos[i-1] != 0xFFFF and all_pos[j] != 0xFFFF
                    lower_val = all_val[i-1]
                    lower_pos = all_pos[i-1]
                    upper_val = all_val[j]
                    upper_pos = all_pos[j]
                    # Fill the hole based on 2 values
                    for i2 in range(i, j):
                        if np.abs(lower_val - i2) < np.abs(upper_val - i2):
                            all_val[i2] = lower_val
                            all_pos[i2] = lower_pos
                        else:
                            all_val[i2] = upper_val
                            all_pos[i2] = upper_pos
                i = j
            i += 1
    
    # Correct values in very pathological cases where z_size is big so there are not enough 
    # values added to the value-array causing some values of the value-array to be incorrect.
    # The number of `while` iterations should always be <= 3 in practice
    @nb.njit('(float64[::1], uint16[::1], int64)')
    def correct_values(all_val, all_pos, z_size):
        while True:
            stop = True
            for i in range(0, z_size-1):
                current = np.abs(all_val[i] - i)
                if np.abs(all_val[i+1] - i) < current:
                    all_val[i] = all_val[i+1]
                    all_pos[i] = all_pos[i+1]
                    stop = False
            for i in range(1, z_size):
                current = np.abs(all_val[i] - i)
                if np.abs(all_val[i-1] - i) < current:
                    all_val[i] = all_val[i-1]
                    all_pos[i] = all_pos[i-1]
                    stop = False
            if stop:
                break
    
    @nb.njit('(float64[:,::1], int64)', parallel=True)
    def compute_fastest(z_coordinates, z_size):
        x_size, y_size = z_coordinates.shape
        assert y_size >= 2 and z_size >= 2
        y_coordinates = np.empty((x_size, z_size), dtype=np.uint16)
        for x in nb.prange(x_size):
            all_pos = np.full(z_size, 0xFFFF, dtype=np.uint16)
            all_val = np.full(z_size, np.inf, dtype=np.float64)
            for y in range(0, y_size):
                val = z_coordinates[x, y]
                #assert not np.isnan(val)
                if val < 0: # Lower bound
                    i = 0
                    if np.abs(val - i) < np.abs(all_val[i] - i):
                        all_val[i] = val
                        all_pos[i] = y
                elif val >= z_size: # Upper bound
                    i = z_size - 1
                    if np.abs(val - i) < np.abs(all_val[i] - i):
                        all_val[i] = val
                        all_pos[i] = y
                else: # Inside the array of values
                    offset = np.int32(val)
                    for i in range(max(offset-1, 0), min(offset+2, z_size)):
                        if np.abs(val - i) < np.abs(all_val[i] - i):
                            all_val[i] = val
                            all_pos[i] = y
            fill_missing_values(all_val, all_pos, z_size)
            correct_values(all_val, all_pos, z_size)
            for i in range(0, z_size):
                y_coordinates[x, i] = all_pos[i]
        return y_coordinates
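
Since parts of the code above are marked as untested, it may be worth validating a result on a reduced input. The sketch below is one possible check, with the hypothetical name max_excess_distance: it compares distances rather than indices, so ties between equally close y values do not produce false alarms. It is demonstrated here on the naive loop from the question.

```python
import numpy as np

def max_excess_distance(z_coordinates, y_coordinates):
    """Largest gap between the chosen and the truly minimal |z_value - z|."""
    x_size, z_size = y_coordinates.shape
    rows = np.arange(x_size)
    worst = 0.0
    for z in range(z_size):
        dist = np.abs(z_coordinates - z)            # shape (x_size, y_size)
        chosen = dist[rows, y_coordinates[:, z]]    # distance of the reported y
        worst = max(worst, float(np.max(chosen - dist.min(axis=1))))
    return worst

# Small self-check using the naive loop from the question as the result under test
rng = np.random.default_rng(123)
z_small = np.linspace(0, 40, 250) + rng.laplace(0, 1, (20, 250))
y_naive = np.empty((20, 40), dtype=np.uint16)
for i in range(40):
    y_naive[:, i] = np.argmin(np.abs(z_small - i), axis=1)
print(max_excess_distance(z_small, y_naive))
```

Calling max_excess_distance(z_coordinates, compute_fastest(z_coordinates, z_size)) on a small array should then give 0.0 if the result is as close as the brute-force answer.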
    

    Performance results

    Here are performance results on my machine with an i5-9600KF CPU (6 cores), Numpy 1.24.3, Numba 0.58.1, on Windows, for the provided input:

    Naive fully vectorized code in the question:   113000 ms  (slow due to swapping)
    Naive loop in the question:                      8460 ms
    ZLi's implementation:                            1964 ms
    Naive Numba parallel code with loops:             402 ms
    PaulS' implementation:                            262 ms
    This Numba code:                                   12 ms  <----------
    

    Note that the fully vectorized code in the question uses so much memory that it causes memory swapping. It completely saturates my 32 GiB of RAM (about 24 GiB was available in practice), which is clearly not reasonable!

    Note that PaulS' implementation is about equally fast with 32-bit and 64-bit on my machine. This is probably because the operation is compute-bound on my machine (this depends on the speed of the RAM).

    This Numba implementation is 705 times faster than the fastest implementation in the question. It is also 22 times faster than the best answer so far! It also uses a tiny amount of additional RAM for the computation (<1 MiB).