Tags: python, performance, numpy, scipy, distance

Fastest code to calculate distance between points in 2D numpy array with cyclic (periodic) boundary conditions


I know how to calculate the Euclidean distance between points in an array using scipy.spatial.distance.cdist.

This is similar to the answers to this question: Calculate Distances Between One Point in Matrix From All Other Points.

However, I would like to make the calculation assuming cyclic (periodic) boundary conditions, so that point [0,0] is at distance 1 from point [0,n-1], not at distance n-1. (I will then build a mask of all points within a threshold distance of my target cells, but that is not central to the question.)
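
For a single axis, the cyclic distance is min(|d|, n-|d|) (the "minimum image" convention). As a minimal sketch of what I mean in 2D (the helper name is mine, purely for illustration):

import numpy as np

def periodic_dist(p, q, n):
    # minimum-image distance on an n x n grid with cyclic boundaries
    d = np.abs(np.asarray(p) - np.asarray(q))
    d = np.minimum(d, n - d) # wrap each axis: min(|d|, n-|d|)
    return np.sqrt((d**2).sum())

# periodic_dist([0, 0], [0, n-1], n) -> 1.0, not n-1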

The only way I can think of is to repeat the calculation 9 times, with n added/subtracted from the domain indices in the x, the y, and then both the x and y directions, then stack the results and take the minimum across the 9 slices. To illustrate the need for 9 repetitions, I put together a simple schematic with just one J-point (marked with a circle); in this example the cell marked by the triangle has its nearest neighbour in the copy of the domain shifted to the top-left.

[Schematic: the domain tiled 3x3 with a single J-point (circle); the triangle cell's nearest J-point lies in the copy shifted to the top-left.]

This is the code I developed for this using cdist:

import numpy as np
from scipy import spatial
    
n=5 # size of 2D box (n X n points)
np.random.seed(1) # to make reproducible
a=np.random.uniform(size=(n,n))
i=np.argwhere(a>-1)  # all points; for each we want the distance to the nearest J
j=np.argwhere(a>0.85) # set of J locations to find distance to

maxdist=2.0 # threshold distance, used in the KDTree solution below

def dist_v1(i,j):
    dist=[]
    # 3x3 search required for periodic boundaries.
    for xoff in [-n,0,n]:
        for yoff in [-n,0,n]:
            jo=j.copy()
            jo[:,0]-=xoff
            jo[:,1]-=yoff
            # distance from every point to the nearest J in this shifted copy
            dist.append(np.amin(spatial.distance.cdist(i,jo,metric='euclidean'),1))
    dist=np.amin(np.stack(dist),0).reshape([n,n]) # minimum across the 9 copies
    return dist

This works and produces, e.g.:

print(dist_v1(i,j))


[[1.41421356 1.         1.41421356 1.41421356 1.        ]
 [2.23606798 2.         1.41421356 1.         1.41421356]
 [2.         2.         1.         0.         1.        ]
 [1.41421356 1.         1.41421356 1.         1.        ]
 [1.         0.         1.         1.         0.        ]]

The zeros obviously mark the J points, and the distances are correct (this EDIT corrects my earlier attempts, which were incorrect).

Note that if you change the last two lines to stack the raw distances and then take only a single minimum, like this:

def dist_v2(i,j):
    dist=[]
    # 3x3 search required for periodic boundaries.
    for xoff in [-n,0,n]:
        for yoff in [-n,0,n]:
            jo=j.copy()
            jo[:,0]-=xoff
            jo[:,1]-=yoff
            dist.append(spatial.distance.cdist(i,jo,metric='euclidean')) 
    dist=np.amin(np.dstack(dist),(1,2)).reshape([n,n]) # one minimum over J points and all 9 copies
    return dist

it is faster for small n (<10) but considerably slower for larger arrays (n>10).

But either way, it is slow for my large arrays (n=500, with around 70 J points): this search takes up about 99% of the calculation time (and it is a bit ugly too, with all those loops). Is there a better/faster way?

The other options I thought of were:

  1. scipy.spatial.KDTree.query_ball_point

With further searching I found that there is a function, scipy.spatial.KDTree.query_ball_point, which directly returns the grid points within a radius of my J-points, but it doesn't seem to have any facility to use periodic boundaries, so I presume one would still need to somehow use a 3x3 loop, stack, and then take the minimum as I do above; I'm therefore not sure whether this would be any faster.

I coded up a solution using this function WITHOUT worrying about the periodic boundary conditions (i.e. this doesn't answer my question):

def dist_v3(n,j):
    x, y = np.mgrid[0:n, 0:n]
    points = np.c_[x.ravel(), y.ravel()] # coordinates of every grid point
    tree=spatial.KDTree(points)
    mask=np.zeros([n,n])
    # mark every grid point within maxdist of any J point
    for results in tree.query_ball_point(j, maxdist):
        mask[points[results][:,0],points[results][:,1]]=1
    return mask
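
(An aside on the KDTree approach: it appears that sufficiently recent SciPy versions can handle the periodic boundaries directly, via the boxsize argument of scipy.spatial.cKDTree, which imposes a toroidal topology on the tree. I haven't benchmarked it; a minimal sketch, assuming a SciPy new enough to have it, would be:)

from scipy.spatial import cKDTree

def dist_v3_periodic(n,j):
    # same as dist_v3, but boxsize=[n,n] wraps both axes with period n
    x, y = np.mgrid[0:n, 0:n]
    points = np.c_[x.ravel(), y.ravel()]
    tree=cKDTree(points, boxsize=[n,n])
    mask=np.zeros([n,n])
    for results in tree.query_ball_point(j, maxdist):
        mask[points[results][:,0],points[results][:,1]]=1
    return mask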

Maybe I'm not using dist_v3 in the most efficient way, but even without the periodic boundaries it is already as slow as my cdist-based solutions. Including the mask step in the two cdist solutions, i.e. replacing the final return dist with return np.where(dist<=maxdist,1,0) in those functions, and then using timeit with n=100, I get the following timings:

from timeit import timeit

print("cdist v1:",timeit(lambda: dist_v1(i,j), number=3)*100)
print("cdist v2:",timeit(lambda: dist_v2(i,j), number=3)*100)
print("KDtree:", timeit(lambda: dist_v3(n,j), number=3)*100)

cdist v1: 181.80927299981704
cdist v2: 554.8205785999016
KDtree: 605.119637199823
  2. Make an array of relative coordinates for points within a set distance of [0,0], and then manually loop over the J points, setting up a mask with this list of relative points (see the sketch after this list). This has the advantage that the "relative distance" calculation is only performed once (my J points change at each timestep), but I suspect the looping will be very slow.

  3. Precalculate a set of masks for EVERY point in the 2D domain, so that in each timestep of the model integration I just pick out the mask for the J-point and apply it. This would use a LOT of memory (proportional to n^4) and would perhaps still be slow, as you would need to loop over the J points to combine the masks.
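
A rough sketch of option 2 (all names here are illustrative, with maxdist=2.0 as above): precompute the stencil of offsets within maxdist of the origin once, then stamp it (modulo n) around each J point:

r = int(np.ceil(maxdist))
ii, jj = np.mgrid[-r:r+1, -r:r+1]
offs = np.argwhere(ii**2 + jj**2 <= maxdist**2) - r # offsets within maxdist of [0,0]

def mask_v4(n, j):
    mask = np.zeros((n,n), dtype=int)
    for pj in j:
        idx = (pj + offs) % n # periodic wrap of the stencil
        mask[idx[:,0], idx[:,1]] = 1
    return mask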


Solution

  • [EDIT] I found a mistake in the way the code keeps track of which points are already done; it is fixed with the mask_kernel below. The pure-Python version of the newer code is ~1.5x slower, but the numba version is slightly faster (due to some other optimisations).

    [current best: ~100x to 120x the original speed]

    First of all, thank you for submitting this problem; I had a lot of fun optimizing it!

    My current best solution relies on the assumption that the grid is regular and that the "source" points (the ones from which we need to compute the distance) are roughly evenly distributed.

    The idea here is that, since the grid points have integer coordinates, all of the distances are going to be 1, sqrt(2), 2, sqrt(5), and so on, so we can do the numerical calculation beforehand. We then put these values in a matrix and copy that matrix around each source point (keeping the minimum value found at each point). This covers the vast majority of the points (>99%); another, more "classical" method is applied to the remaining <1%.
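
    To make the kernel concrete: for kern_size=2, the precomputed matrix of squared distances (built exactly as in apply_kernel below) is:

    ks = 2
    vals = np.arange(-ks, ks+1)**2    # [4, 1, 0, 1, 4]
    kernel = np.add.outer(vals, vals) # squared distance to the centre cell
    # [[8 5 4 5 8]
    #  [5 2 1 2 5]
    #  [4 1 0 1 4]
    #  [5 2 1 2 5]
    #  [8 5 4 5 8]]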

    Here's the code:

    import numpy as np
    
    def sq_distance(x1, y1, x2, y2, n):
        # computes the pairwise squared distance between 2 sets of points (with periodicity)
        # x1, y1 : coordinates of the first set of points (source)
        # x2, y2 : coordinates of the second set of points
        dx = np.abs((np.subtract.outer(x1, x2) + n//2) % n - n//2)
        dy = np.abs((np.subtract.outer(y1, y2) + n//2) % n - n//2)
        return dx*dx + dy*dy
    
    def apply_kernel(sources, sqdist, kern_size, n, mask):
        ker_i, ker_j = np.meshgrid(np.arange(-kern_size, kern_size+1), np.arange(-kern_size, kern_size+1), indexing="ij")
        kernel = np.add.outer(np.arange(-kern_size, kern_size+1)**2, np.arange(-kern_size, kern_size+1)**2)
        mask_kernel = kernel > kern_size**2 # cells beyond the inscribed circle are not guaranteed covered

        for pi, pj in sources:
            ind_i = (pi+ker_i)%n
            ind_j = (pj+ker_j)%n
            sqdist[ind_i,ind_j] = np.minimum(kernel, sqdist[ind_i,ind_j]) # keep the smallest squared distance seen
            mask[ind_i,ind_j] *= mask_kernel # mark circle-covered cells as done
    
    def dist_vf(sources, n, kernel_size):
        sources = np.asfortranarray(sources) #for memory contiguity
    
        kernel_size = min(kernel_size, n//2)
        kernel_size = max(kernel_size, 1)
    
        sqdist = np.full((n,n), 10*n**2, dtype=np.int32) #preallocate with a huge distance (>max**2)
        mask   = np.ones((n,n), dtype=bool)              #which points have not been reached?
    
        #main code
        apply_kernel(sources, sqdist, kernel_size, n, mask) 
    
        #remaining points
        rem_i, rem_j = np.nonzero(mask)
        if len(rem_i) > 0:
            sq_d = sq_distance(sources[:,0], sources[:,1], rem_i, rem_j, n).min(axis=0)
            sqdist[rem_i, rem_j] = sq_d
    
        #eff = 1-rem_i.size/n**2
        #print("covered by kernel :", 100*eff, "%")
        #print("overlap :", sources.shape[0]*(1+2*kernel_size)**2/n**2)
        #print()
    
        return np.sqrt(sqdist)
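
    As a quick sanity check (my addition, reusing the small n=5 example from the question), the kernel method should reproduce the brute-force cdist result exactly:

    # distances from the kernel method match the 3x3-tiled cdist version
    assert np.allclose(dist_v1(i, j), dist_vf(j, n, kernel_size=2))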
    
    

    Testing this version with

    n=500  # size of 2D box (n X n points)
    np.random.seed(1) # to make reproducible
    a=np.random.uniform(size=(n,n)) 
    all_points=np.argwhere(a>-1)  # all points, for each loc we want distance to nearest J 
    source_points=np.argwhere(a>1-70/n**2) # set of J locations to find distance to.
    
    #
    # code for dist_v1 and dist_vf
    #
    
    overlap=5.2
    kernel_size = int(np.sqrt(overlap*n**2/source_points.shape[0])/2)
    
    print("cdist v1      :", timeit(lambda: dist_v1(all_points,source_points), number=1)*1000, "ms")
    print("kernel version:", timeit(lambda: dist_vf(source_points, n, kernel_size), number=10)*100, "ms")
    
    

    gives

    cdist v1      : 1148.6694 ms
    kernel version: 69.21876999999998 ms
    

    which is already a ~17x speedup! I also implemented a numba version of sq_distance and apply_kernel [this is the new, correct version]:

    from numba import njit

    @njit(cache=True)
    def sq_distance(x1, y1, x2, y2, n):
        m1 = x1.size
        m2 = x2.size
        n2 = n//2
        d = np.empty((m1,m2), dtype=np.int32)
        for i in range(m1):
            for j in range(m2):
                dx = np.abs(x1[i] - x2[j] + n2)%n - n2
                dy = np.abs(y1[i] - y2[j] + n2)%n - n2
                d[i,j] = dx*dx + dy*dy
        return d
    
    @njit(cache=True)
    def apply_kernel(sources, sqdist, kern_size, n, mask):
        # creating the kernel
        kernel = np.empty((2*kern_size+1, 2*kern_size+1))
        vals = np.arange(-kern_size, kern_size+1)**2
        for i in range(2*kern_size+1):
            for j in range(2*kern_size+1):
                kernel[i,j] = vals[i] + vals[j]
        mask_kernel = kernel > kern_size**2
    
        I = sources[:,0]
        J = sources[:,1]
    
        # applying the kernel for each point
        for l in range(sources.shape[0]):
            pi = I[l]
            pj = J[l]
    
            if pj - kern_size >= 0 and pj + kern_size<n: #if we are in the middle, no need to do the modulo for j
                for i in range(2*kern_size+1):
                    ind_i = np.mod((pi+i-kern_size), n)
                    for j in range(2*kern_size+1):
                        ind_j = (pj+j-kern_size)
                        sqdist[ind_i,ind_j] = np.minimum(kernel[i,j], sqdist[ind_i,ind_j])
                        mask[ind_i,ind_j] = mask_kernel[i,j] and mask[ind_i,ind_j]
    
            else:
                for i in range(2*kern_size+1):
                    ind_i = np.mod((pi+i-kern_size), n)
                    for j in range(2*kern_size+1):
                        ind_j = np.mod((pj+j-kern_size), n)
                        sqdist[ind_i,ind_j] = np.minimum(kernel[i,j], sqdist[ind_i,ind_j])
                        mask[ind_i,ind_j] = mask_kernel[i,j] and mask[ind_i,ind_j]
        return
    
    

    and testing with

    overlap=5.2
    kernel_size = int(np.sqrt(overlap*n**2/source_points.shape[0])/2)
    
    print("cdist v1                :", timeit(lambda: dist_v1(all_points,source_points), number=1)*1000, "ms")
    print("kernel numba (first run):", timeit(lambda: dist_vf(source_points, n, kernel_size), number=1)*1000, "ms") #first run = cimpilation = long
    print("kernel numba            :", timeit(lambda: dist_vf(source_points, n, kernel_size), number=10)*100, "ms")
    

    which gave the following results

    cdist v1                : 1163.0742 ms
    kernel numba (first run): 2060.0802 ms
    kernel numba            : 8.80377000000001 ms
    

    Due to the JIT compilation, the first run is pretty slow, but otherwise it's a 120x improvement!

    It may be possible to get a little more out of this algorithm by tweaking the kernel_size parameter (or the overlap). The current choice of kernel_size is only effective for a small number of source points; for example, it fails miserably with source_points=np.argwhere(a>0.85) (13 s), while manually setting kernel_size=5 gives the answer in 22 ms.
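
    As a rough check on that sizing rule (my arithmetic, using the numbers from these tests):

    n, overlap = 500, 5.2
    int(np.sqrt(overlap * n**2 / 70) / 2)    # ~70 sources -> kernel_size = 68
    int(np.sqrt(overlap * n**2 / 37500) / 2) # a>0.85 gives ~37500 sources -> kernel_size = 2
    # a kernel of only 2 leaves many cells to the slow exact fallback,
    # which is why hand-picking kernel_size=5 is so much faster there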

    I hope my post isn't (unnecessarily) too complicated; I don't really know how to organise it better.

    [EDIT 2]:

    I gave a little more attention to the non-numba part of the code and managed to get a pretty significant speedup, getting very close to what numba can achieve. Here is the new version of apply_kernel:

    def apply_kernel(sources, sqdist, kern_size, n, mask):
        ker_i = np.arange(-kern_size, kern_size+1).reshape((2*kern_size+1,1))
        ker_j = np.arange(-kern_size, kern_size+1).reshape((1,2*kern_size+1))
    
        kernel = np.add.outer(np.arange(-kern_size, kern_size+1)**2, np.arange(-kern_size, kern_size+1)**2)
        mask_kernel = kernel > kern_size**2
    
        for pi, pj in sources:
    
            imin = pi-kern_size
            jmin = pj-kern_size
            imax = pi+kern_size+1
            jmax = pj+kern_size+1
            if imax < n and jmax < n and imin >=0 and jmin >=0: # we are inside
                sqdist[imin:imax,jmin:jmax] = np.minimum(kernel, sqdist[imin:imax,jmin:jmax])
                mask[imin:imax,jmin:jmax] *= mask_kernel
    
        elif imax < n and imin >= 0: # only the j-range wraps around
                ind_j = (pj+ker_j.ravel())%n
                sqdist[imin:imax,ind_j] = np.minimum(kernel, sqdist[imin:imax,ind_j])
                mask[imin:imax,ind_j] *= mask_kernel
    
        elif jmax < n and jmin >= 0: # only the i-range wraps around
                ind_i = (pi+ker_i.ravel())%n
                sqdist[ind_i,jmin:jmax] = np.minimum(kernel, sqdist[ind_i,jmin:jmax])
                mask[ind_i,jmin:jmax] *= mask_kernel
    
        else: # both ranges wrap around
                ind_i = (pi+ker_i)%n
                ind_j = (pj+ker_j)%n
                sqdist[ind_i,ind_j] = np.minimum(kernel, sqdist[ind_i,ind_j])
                mask[ind_i,ind_j] *= mask_kernel
    

    The main optimisations are (see the toy illustration after this list):

    • Indexing with slices (rather than a dense array of indices)
    • Use of sparse 1-D indexes instead of dense 2-D ones (how did I not think about that earlier)
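
    A toy illustration of the difference (not from the code above): basic slicing produces a view, while fancy indexing with integer arrays materialises a copy of the selected block.

    arr = np.zeros((1000, 1000))
    idx = np.arange(100)
    block_view = arr[0:100, 0:100]      # basic slice: a view, no data copied
    block_copy = arr[idx[:,None], idx]  # fancy indexing: allocates a new 100x100 array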

    Testing with

    overlap=5.4
    kernel_size = int(np.sqrt(overlap*n**2/source_points.shape[0])/2)
    
    print("cdist v1  :", timeit(lambda: dist_v1(all_points,source_points), number=1)*1000, "ms")
    print("kernel v2 :", timeit(lambda: dist_vf(source_points, n, kernel_size), number=10)*100, "ms")
    

    gives

    cdist v1  : 1209.8163000000002 ms
    kernel v2 : 11.319049999999997 ms
    

    which is a nice 100x improvement over cdist, a ~5.5x improvement over the previous numpy-only version, and only ~25% slower than what I could achieve with numba.