
Optimizing nested iterators in Python


Say that fig1 and fig2 are NumPy 2D arrays of shape (n1, 2) and (n2, 2), which are basically lists of coordinates. What I'm trying to do is compute a mean Hausdorff distance between two figures represented as sets of points. Here is what I've come up with:

import numpy as np
from numba import jit

@jit(nopython=True, fastmath=True)
def dist(fig1, fig2):
    n1 = fig1.shape[0]
    n2 = fig2.shape[0]
    dist_matrix = np.zeros((n1, n2))
    min1 = np.full(n1, np.inf)
    min2 = np.full(n2, np.inf)

    for i in range(n1):
        for j in range(n2):
            dist_matrix[i, j] = np.sqrt((fig1[i, 0] - fig2[j, 0])**2 + (fig1[i, 1] - fig2[j, 1])**2)
            min1[i] = np.minimum(min1[i], dist_matrix[i, j])
            min2[j] = np.minimum(min2[j], dist_matrix[i, j])

    d1 = np.mean(min1)
    d2 = np.mean(min2)
    h_dist = np.maximum(d1, d2)
    return h_dist
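
For reference, written out, the quantity this code computes is: for point sets $A$ (fig1) and $B$ (fig2),

$$d(A, B) = \max\left(\frac{1}{|A|}\sum_{a \in A}\min_{b \in B}\lVert a - b\rVert,\ \frac{1}{|B|}\sum_{b \in B}\min_{a \in A}\lVert a - b\rVert\right)$$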

That works perfectly, but I'd like to avoid these nested for loops, hopefully via vectorization. Can somebody suggest any possible variants?

If needed, I can provide some generators for the figures.


Solution

  • Two Methods

    1. Using KDTree
    2. Numba with posted approach

    Code

    # Method 1--KDTree
    import numpy as np
    from sklearn.neighbors import KDTree
    
    def dist_kdtree(fig1, fig2):
        ' Distance using KDTree '
    
        # KDTree data structures for fig1 & fig2,
        # which provide fast nearest-neighbor queries
        fig1_tree = KDTree(fig1)
        fig2_tree = KDTree(fig2)
    
        # Nearest neighbors of each point of fig1 in fig2
        nearest_dist1, nearest_ind1 = fig2_tree.query(fig1, k=1)
    
        # Nearest neighbors of each point of fig2 in fig1
        nearest_dist2, nearest_ind2 = fig1_tree.query(fig2, k=1)
    
        # Mean of distances
        d1 = np.mean(nearest_dist1)
        d2 = np.mean(nearest_dist2)
        return np.maximum(d1, d2)
    
    # Method 2--Numba on posted approach
    import numpy as np
    from numba import jit
    
    @jit(nopython=True) # Set "nopython" mode for best performance, equivalent to @njit
    def dist_numba(fig1, fig2):
        n1 = fig1.shape[0]
        n2 = fig2.shape[0]
        dist_matrix = np.zeros((n1, n2))
        min1 = np.full(n1, np.inf)
        min2 = np.full(n2, np.inf)
    
        for i in range(n1):
            for j in range(n2):
                dist_matrix[i, j] = np.sqrt((fig1[i, 0] - fig2[j, 0])**2 + (fig1[i, 1] - fig2[j, 1])**2)
                min1[i] = np.minimum(min1[i], dist_matrix[i, j])
                min2[j] = np.minimum(min2[j], dist_matrix[i, j])
    
        d1 = np.mean(min1)
        d2 = np.mean(min2)
        h_dist = np.maximum(d1, d2)
        return h_dist
    
    # Baseline--Original version
    import numpy as np
    
    def dist(fig1, fig2):
        n1 = fig1.shape[0]
        n2 = fig2.shape[0]
        dist_matrix = np.zeros((n1, n2))
        min1 = np.full(n1, np.inf)
        min2 = np.full(n2, np.inf)
    
        for i in range(n1):
            for j in range(n2):
                dist_matrix[i, j] = np.sqrt((fig1[i, 0] - fig2[j, 0])**2 + (fig1[i, 1] - fig2[j, 1])**2)
                min1[i] = np.minimum(min1[i], dist_matrix[i, j])
                min2[j] = np.minimum(min2[j], dist_matrix[i, j])
    
        d1 = np.mean(min1)
        d2 = np.mean(min2)
        h_dist = np.maximum(d1, d2)
        return h_dist
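
    A quick sanity check (a suggested snippet, assuming the three functions above are defined) that all implementations agree to floating-point tolerance on random input:

    # Sanity check--all three implementations should give the same result
    np.random.seed(0)
    a = np.random.random((50, 2))
    b = np.random.random((80, 2))
    assert np.isclose(dist(a, b), dist_kdtree(a, b))
    assert np.isclose(dist(a, b), dist_numba(a, b))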
    

    Performance Test

    # Create two random 2D arrays
    np.random.seed(0)
    fig1 = np.random.random((100, 2))
    fig2 = np.random.random((500, 2))
    
    # Time original
    %timeit dist(fig1, fig2)
    Out: 815 ms ± 49.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    
    # Time KDTree
    %timeit dist_kdtree(fig1, fig2)
    Out: 1.66 ms ± 88.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    
    # Time Numba
    %timeit dist_numba(fig1, fig2)
    Out: 770 µs ± 11.2 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
    

    Summary

    • KDTree provided a ~490× speedup over the original (1.66 ms vs. 815 ms)
    • Numba provided a ~1,000× speedup over the original (770 µs vs. 815 ms)
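
    Vectorized Alternative

    For completeness, since the question asks about vectorization: one possible loop-free NumPy sketch, which builds the full (n1, n2) distance matrix via broadcasting and so trades memory for the removed Python loops. scipy.spatial.distance.cdist(fig1, fig2) could be used in place of the broadcasting step.

    # Fully vectorized variant--no explicit Python loops
    import numpy as np

    def dist_vectorized(fig1, fig2):
        # Broadcasting: (n1, 1, 2) - (1, n2, 2) -> (n1, n2, 2) differences
        diff = fig1[:, None, :] - fig2[None, :, :]
        dist_matrix = np.sqrt((diff ** 2).sum(axis=-1))
        d1 = dist_matrix.min(axis=1).mean()  # mean nearest distance, fig1 -> fig2
        d2 = dist_matrix.min(axis=0).mean()  # mean nearest distance, fig2 -> fig1
        return max(d1, d2)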