
Why are np.hypot and np.subtract.outer so fast compared to vanilla broadcasting, and are there faster ways to calculate a distance matrix?


I have two large sets of 2D points and need to calculate a distance matrix between them. I need it to be fast, so I used NumPy broadcasting. Of the two ways I tried to calculate the distance matrix, I don't understand why one is better than the other.

From here I get what look like contradictory results. Cells [3, 4, 5] and [8, 9] both calculate the distance matrix, but Cells 3+4, which use subtract.outer, are faster than Cell 8, which uses broadcasting, and Cell 5, which uses hypot, is faster than Cell 9, the straightforward way. I did not try plain Python loops, assuming they would never finish.

  1. Is there a faster way to calculate the distance matrix?
  2. Why are hypot and subtract.outer faster?

Code (I change the seed between runs to prevent cache re-use):

### Cell 1
import numpy as np

np.random.seed(858442)

### Cell 2
%%time
obs = np.random.random((50000, 2))
interp = np.random.random((30000, 2))

CPU times: user 2.02 ms, sys: 1.4 ms, total: 3.42 ms
Wall time: 1.84 ms

### Cell 3
%%time
d0 = np.subtract.outer(obs[:,0], interp[:,0])

CPU times: user 2.46 s, sys: 1.97 s, total: 4.42 s
Wall time: 4.42 s

### Cell 4
%%time
d1 = np.subtract.outer(obs[:,1], interp[:,1])

CPU times: user 3.1 s, sys: 2.7 s, total: 5.8 s
Wall time: 8.34 s

### Cell 5
%%time
h = np.hypot(d0, d1)

CPU times: user 12.7 s, sys: 24.6 s, total: 37.3 s
Wall time: 1min 6s

### Cell 6
np.random.seed(773228)

### Cell 7
%%time
obs = np.random.random((50000, 2))
interp = np.random.random((30000, 2))

CPU times: user 1.84 ms, sys: 1.56 ms, total: 3.4 ms
Wall time: 2.03 ms

### Cell 8
%%time
d = obs[:, np.newaxis, :] - interp
d0, d1 = d[:, :, 0], d[:, :, 1]

CPU times: user 22.7 s, sys: 8.24 s, total: 30.9 s
Wall time: 33.2 s

### Cell 9
%%time
h = np.sqrt(d0**2 + d1**2)

CPU times: user 29.1 s, sys: 2min 12s, total: 2min 41s
Wall time: 6min 10s

Solution

  • First of all, d0 and d1 each take 50000 x 30000 x 8 = 12 GB, which is pretty big. Make sure you have more than 100 GB of memory, because this is what the whole script requires! This is a huge amount of memory. If you do not have enough, the operating system will use a storage device (e.g. swap) to store excess data, which is much slower. Actually, there is no reason Cell 4 should be slower than Cell 3, and I guess you already do not have enough memory to (fully) store d1 in RAM, while d0 seems to fit (mostly) in memory. There is no difference on my machine when both fit in RAM (one can also reverse the order of the operations to check this). This also explains why further operations tend to get slower.
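
    As a quick sanity check (a sketch of mine, not part of the original answer, assuming the default float64 dtype), the size can be verified with plain arithmetic:

    # 8 bytes per float64 element
    print(50_000 * 30_000 * 8 / 1e9)  # 12.0 GB per intermediate array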

    That being said, Cells 8+9 are also slower because they create temporary arrays and need more passes over memory to compute the result than Cells 3+4+5. Indeed, the expression np.sqrt(d0**2 + d1**2) first computes d0**2, resulting in a new 12 GB temporary array, then computes d1**2, resulting in another 12 GB temporary array, then sums the two temporaries to produce yet another new 12 GB temporary array, and finally computes the square root, resulting in one more 12 GB temporary array. This can require up to 48 GB of memory and needs 4 memory-bound read-write passes. This is not efficient and does not use the CPU/RAM well (e.g. the CPU cache).
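
    One way to see the cost of those temporaries (a sketch of mine, not from the original answer) is to reuse the existing buffers through the ufunc out= parameter, which removes the extra allocations at the cost of clobbering d0 and d1:

    # In-place variant: no new 12 GB temporaries are allocated.
    np.multiply(d0, d0, out=d0)  # d0 <- d0**2
    np.multiply(d1, d1, out=d1)  # d1 <- d1**2
    np.add(d0, d1, out=d0)       # d0 <- d0**2 + d1**2
    np.sqrt(d0, out=d0)          # d0 now holds the distance matrix

    This still makes several passes over 12 GB buffers, whereas np.hypot(d0, d1) fuses the squaring, addition and square root into a single pass, which is part of why Cell 5 beats Cell 9.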

    There is a much faster implementation: do the whole computation in one pass, in parallel, using Numba's JIT. Here is an example:

    import numpy as np
    import numba as nb

    @nb.njit(parallel=True)
    def distanceMatrix(a, b):
        # Allocate the result once and compute each distance in a single
        # fused pass, parallelized over the rows of `a`.
        res = np.empty((a.shape[0], b.shape[0]), dtype=a.dtype)
        for i in nb.prange(a.shape[0]):
            for j in range(b.shape[0]):
                res[i, j] = np.sqrt((a[i, 0] - b[j, 0])**2 + (a[i, 1] - b[j, 1])**2)
        return res
    

    This implementation uses 3 times less memory (only 12 GB) and is much faster than the one based on subtract.outer. Indeed, due to swapping, Cells 3+4+5 take a few minutes, while this one takes 1.3 seconds!
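
    For reference, calling it on the arrays from Cell 2 would look like this (a hypothetical usage example, not from the original answer):

    h = distanceMatrix(obs, interp)  # shape (50000, 30000), ~12 GB of float64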

    The takeaway is that memory accesses are expensive, and so are temporary arrays. One needs to avoid making multiple passes over huge buffers and to take advantage of CPU caches when the computation performed is not trivial (for example by using array chunks, as sketched below).
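
    As an illustration of the chunking idea, here is a minimal pure-NumPy sketch (mine, not from the original answer; the function name and chunk size are illustrative) that fills the full 12 GB result block by block so the broadcast temporaries stay small:

    import numpy as np

    def distance_matrix_chunked(a, b, chunk=256):
        # Preallocate the full result once, then fill it block by block so
        # each broadcast temporary is only (chunk, len(b), 2) instead of 12 GB.
        res = np.empty((a.shape[0], b.shape[0]), dtype=a.dtype)
        for start in range(0, a.shape[0], chunk):
            stop = min(start + chunk, a.shape[0])
            d = a[start:stop, None, :] - b[None, :, :]
            np.hypot(d[..., 0], d[..., 1], out=res[start:stop])
        return res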