# Improve performance speed on batched Mahalanobis distance computation

I have the following piece of code that computes the Mahalanobis distance over a set of batched features. On my device it takes around 100 ms, most of which is spent in the matrix multiplication between `delta` and `inv_covariance`.

`delta` is an array of dimension 874x32x100, and `inv_covariance` is of dimension 874x100x100.

```
import numpy as np


def compute_distance(embedding: np.ndarray, mean: np.ndarray, inv_covariance: np.ndarray) -> np.ndarray:
    batch, channel, height, width = embedding.shape
    embedding = embedding.reshape(batch, channel, height * width)
    # calculate mahalanobis distances
    delta = np.ascontiguousarray((embedding - mean).transpose(2, 0, 1))
    distances = ((delta @ inv_covariance) * delta).sum(2).transpose(1, 0)
    distances = distances.reshape(batch, 1, height, width)
    distances = np.sqrt(distances.clip(0))
    return distances
```

I've tried to convert the code to use Numba with `@njit`. I preallocated the intermediate matrix and perform the smaller matrix multiplications in a for loop, since matmul is not supported there for 3-dimensional arrays.

```
def compute_distance(embedding: np.ndarray, mean: np.ndarray, inv_covariance: np.ndarray) -> np.ndarray:
    batch, channel, height, width = embedding.shape
    embedding = embedding.reshape(batch, channel, height * width)
    # calculate mahalanobis distances
    delta = np.ascontiguousarray((embedding - mean).transpose(2, 0, 1))
    inv_covariance = np.ascontiguousarray(inv_covariance)
    intermediate_matrix = np.zeros_like(delta)
    for i in range(intermediate_matrix.shape[0]):
        intermediate_matrix[i] = delta[i] @ inv_covariance[i]
    distances = (intermediate_matrix * delta).sum(2).transpose(1, 0)
    distances = np.ascontiguousarray(distances)
    distances = distances.reshape(batch, 1, height, width)
    distances = np.sqrt(distances.clip(0))
    return distances
```

I added a few `ascontiguousarray` calls: the last one is required or the code doesn't work, while the others are there to suppress the warning saying that `@` would perform faster on contiguous arrays (it doesn't seem to make much difference).

Is there a way to make the code faster, either by improving this or rethinking it in a different mathematical way?

Based on Jérôme Richard's answer, I ended up with this code:

```
import numba as nb
import numpy as np


@nb.njit()
def matmul(delta: np.ndarray, inv_covariance: np.ndarray):
    """Computes ((delta @ inv_covariance) * delta).sum(1) using numba.

    Args:
        delta: Matrix of dimension BxD
        inv_covariance: Matrix of dimension DxD

    Returns:
        Vector of length B
    """
    si, sj, sk = delta.shape[0], inv_covariance.shape[1], delta.shape[1]
    assert sk == inv_covariance.shape[0]
    line = np.zeros(sj, dtype=delta.dtype)
    res = np.zeros(si, dtype=delta.dtype)
    for i in range(si):
        line.fill(0.0)
        # line = delta[i] @ inv_covariance, reading inv_covariance row by row
        for k in range(sk):
            factor = delta[i, k]
            for j in range(sj):
                line[j] += factor * inv_covariance[k, j]
        # res[i] = dot(line, delta[i])
        for j in range(sj):
            res[i] += line[j] * delta[i, j]
    return res


@nb.njit
def mean_subtraction(embeddings: np.ndarray, mean: np.ndarray):
    """Computes embeddings - mean using numba, this is required as I have errors with the default numpy
    implementation.

    Args:
        embeddings: Embedding matrix of dimension FxBxD
        mean: Mean matrix of dimension BxD

    Returns:
        Delta matrix of dimension FxBxD
    """
    output_matrix = np.zeros_like(embeddings)
    for i in range(embeddings.shape[0]):
        output_matrix[i] = embeddings[i] - mean
    return output_matrix


@nb.njit(parallel=True)
def compute_distance_numba(embedding: np.ndarray, mean: np.ndarray, inv_covariance: np.ndarray) -> np.ndarray:
    """Compute distance score using numba.

    Args:
        embedding: Embedding Vector
        mean: Mean of the multivariate Gaussian distribution
        inv_covariance: Inverse Covariance matrix of the multivariate Gaussian distribution.
    """
    batch, channel, height, width = embedding.shape
    embedding = embedding.reshape(batch, channel, height * width)
    delta = np.ascontiguousarray(mean_subtraction(embedding, mean).transpose(2, 0, 1))
    inv_covariance = np.ascontiguousarray(inv_covariance)
    # dtype is passed explicitly to avoid the default np.float64
    intermediate_matrix = np.zeros((delta.shape[0], delta.shape[1]), dtype=delta.dtype)
    for i in nb.prange(intermediate_matrix.shape[0]):
        intermediate_matrix[i] = matmul(delta[i], inv_covariance[i])
    distances = intermediate_matrix.transpose(1, 0)
    distances = np.ascontiguousarray(distances)
    distances = distances.reshape(batch, 1, height, width)
    distances = np.sqrt(distances.clip(0))
    return distances
```

Changes compared to the accepted answer are the custom function for subtraction and the addition of dtype for the intermediate matrix to avoid default np.float64.
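As a small illustration of the dtype point (a sketch only, assuming the embeddings are stored as `float32`; the variable names are made up for this example), omitting `dtype` in `np.zeros` silently gives a `float64` buffer:

```python
import numpy as np

# Assumption for illustration: the features are float32.
delta = np.ones((4, 3), dtype=np.float32)

# dtype omitted: NumPy falls back to its default, float64.
default_buf = np.zeros((delta.shape[0], delta.shape[1]))
# dtype forwarded from the input: the buffer stays float32.
typed_buf = np.zeros((delta.shape[0], delta.shape[1]), dtype=delta.dtype)

print(default_buf.dtype)  # float64
print(typed_buf.dtype)    # float32
```

With the default buffer, every later step (transpose, reshape, `np.sqrt`) also works on `float64`, which doubles the memory traffic for that array.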

## Solution

First of all, matrix multiplications are performed by BLAS libraries, and most implementations are efficient parallel ones. That being said, for a batch of small matrices, the parallel implementation cannot be very efficient: the granularity is too small, so the overhead of using multiple threads becomes significant. It is **better to parallelize the outer loop and use a sequential matrix-multiplication code**.

Since the matrices involved in the matrix multiplication are pretty small, it is better to **reimplement the matrix multiplication manually**. This removes the overhead of calling functions of the matrix multiplication library (BLAS) and also ensures no threads are used during the matrix multiplication. One needs to take care to read and write values contiguously, though, so that the operation is **SIMD-friendly**.

On top of that, **the matrix multiplication can be merged with the next line**, `(intermediate_matrix * delta).sum(2)`, so as to write a smaller output array and avoid reading big temporary arrays back. This is critical since *the RAM is slow*. This strategy also reduces the memory footprint while being faster and scaling better. It is certainly a good idea to also merge the operation with the line `(embedding - mean).transpose(2, 0, 1)`, though I did not test it.
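To put rough numbers on the merged multiplication and reduction, here is a back-of-the-envelope calculation. It assumes `float32` data and the shapes given in the question (874x32x100 for `delta`); the variable names are illustrative only.

```python
import numpy as np

# Shapes from the question: delta is (pixels, batch, channel).
pixels, batch, channel = 874, 32, 100
itemsize = np.dtype(np.float32).itemsize  # assumption: float32 data, 4 bytes

# Unmerged: delta @ inv_covariance is stored as a full temporary array.
intermediate_bytes = pixels * batch * channel * itemsize
# Merged: only one value per (pixel, batch) pair is written.
fused_bytes = pixels * batch * itemsize

print(intermediate_bytes)                 # 11187200
print(fused_bytes)                        # 111872
print(intermediate_bytes // fused_bytes)  # 100
```

The temporary array is `channel` times larger than the merged output, so merging avoids writing and re-reading roughly 11 MB per call under these assumptions.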

Here is an implementation considering all the points except the very last one:

```
import numba as nb
import numpy as np


@nb.njit()
def matmul(delta, inv_covariance):
    # Computes ((delta @ inv_covariance) * delta).sum(1) without a 2D temporary
    si, sj, sk = delta.shape[0], inv_covariance.shape[1], delta.shape[1]
    assert sk == inv_covariance.shape[0]
    line = np.zeros(sj, dtype=delta.dtype)
    res = np.zeros(si, dtype=delta.dtype)
    for i in range(si):
        line.fill(0.0)
        # line = delta[i] @ inv_covariance; the inner loop reads contiguously
        for k in range(sk):
            factor = delta[i, k]
            for j in range(sj):
                line[j] += factor * inv_covariance[k, j]
        # res[i] = dot(line, delta[i])
        for j in range(sj):
            res[i] += line[j] * delta[i, j]
    return res


@nb.njit(parallel=True)
def compute_distance(embedding: np.ndarray, mean: np.ndarray, inv_covariance: np.ndarray) -> np.ndarray:
    batch, channel, height, width = embedding.shape
    embedding = embedding.reshape(batch, channel, height * width)
    # calculate mahalanobis distances
    delta = np.ascontiguousarray((embedding - mean).transpose(2, 0, 1))
    inv_covariance = np.ascontiguousarray(inv_covariance)
    intermediate_matrix = np.zeros((delta.shape[0], delta.shape[1]))
    # parallelize the outer loop; each matmul call is sequential
    for i in nb.prange(intermediate_matrix.shape[0]):
        intermediate_matrix[i] = matmul(delta[i], inv_covariance[i])
    distances = intermediate_matrix.transpose(1, 0)
    distances = np.ascontiguousarray(distances)
    distances = distances.reshape(batch, 1, height, width)
    distances = np.sqrt(distances.clip(0))
    return distances
```

This is about **3 times faster** on my i5-9600KF CPU (6 cores). Most of the time appears to be spent in the first line (the computation of `delta`), which can also be merged for better performance (assuming the array strides are reasonable). Note that the compilation time is not included in the timings and that the results are equal (based on `np.allclose`).
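The merge of the `delta` computation suggested above was not tested in the answer. Below is a minimal, unbenchmarked sketch of what it could look like, together with an `np.allclose` check against the plain NumPy version. The names `_fused_kernel`, `compute_distance_fused` and `compute_distance_reference` are hypothetical, `mean` is assumed to have shape (channel, height * width), and the pure-Python fallback is there only so the snippet still runs (slowly) without Numba.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:
    # Fallback so the sketch still runs without Numba installed.
    prange = range

    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda func: func


@njit(parallel=True)
def _fused_kernel(embedding, mean, inv_covariance):
    # embedding: (batch, channel, pixels), mean: (channel, pixels),
    # inv_covariance: (pixels, channel, channel)
    batch, channel, pixels = embedding.shape
    out = np.empty((batch, pixels), dtype=embedding.dtype)
    for p in prange(pixels):  # outer loop parallelized over pixels
        d = np.empty(channel, dtype=embedding.dtype)     # delta of one sample
        line = np.empty(channel, dtype=embedding.dtype)  # d @ inv_covariance[p]
        for b in range(batch):
            # Subtraction fused in: no full delta array is materialized.
            for c in range(channel):
                d[c] = embedding[b, c, p] - mean[c, p]
            line.fill(0.0)
            for k in range(channel):
                factor = d[k]
                for j in range(channel):
                    line[j] += factor * inv_covariance[p, k, j]
            acc = 0.0
            for j in range(channel):
                acc += line[j] * d[j]
            out[b, p] = np.sqrt(max(acc, 0.0))  # clip at 0, then sqrt
    return out


def compute_distance_fused(embedding, mean, inv_covariance):
    batch, channel, height, width = embedding.shape
    flat = embedding.reshape(batch, channel, height * width)
    return _fused_kernel(flat, mean, inv_covariance).reshape(batch, 1, height, width)


def compute_distance_reference(embedding, mean, inv_covariance):
    # Plain NumPy version from the question, used as the baseline.
    batch, channel, height, width = embedding.shape
    delta = (embedding.reshape(batch, channel, height * width) - mean).transpose(2, 0, 1)
    distances = ((delta @ inv_covariance) * delta).sum(2).transpose(1, 0)
    return np.sqrt(distances.reshape(batch, 1, height, width).clip(0))


# Small random inputs (sizes chosen arbitrarily for a quick check).
rng = np.random.default_rng(0)
batch, channel, height, width = 3, 5, 2, 4
embedding = rng.normal(size=(batch, channel, height, width))
mean = rng.normal(size=(channel, height * width))
a = rng.normal(size=(height * width, channel, channel))
inv_covariance = a @ a.transpose(0, 2, 1) + np.eye(channel)  # positive definite

fused = compute_distance_fused(embedding, mean, inv_covariance)
reference = compute_distance_reference(embedding, mean, inv_covariance)
print(np.allclose(fused, reference))
```

Note that `embedding[b, c, p]` is read with a large stride over `c` in this layout, which is the "array strides" caveat mentioned above. Whether the merge pays off therefore depends on the memory layout and should be measured on the real data.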
