
Improving the speed of a batched Mahalanobis distance computation


I have the following piece of code that computes the Mahalanobis distance over a set of batched features. On my device it takes around 100 ms, most of which is spent in the matrix multiplication between delta and inv_covariance.

delta has shape 874x32x100 (height * width, batch, channel) and inv_covariance has shape 874x100x100.
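
For reference, this is the quantity being computed. The notation is introduced here, not by the question: $p$ indexes the $N = H \cdot W = 874$ pixels, $b$ the $B = 32$ batch items, and $C = 100$ is the number of channels.

$$\boldsymbol\delta_{p,b} = \mathbf{x}_{b,:,p} - \boldsymbol\mu_{:,p}, \qquad d_{b,p} = \sqrt{\max\left(0,\ \boldsymbol\delta_{p,b}^\top \Sigma_p^{-1} \boldsymbol\delta_{p,b}\right)}$$

The `delta @ inv_covariance` product alone costs $N \cdot B \cdot C^2 = 874 \cdot 32 \cdot 100^2 \approx 2.8 \times 10^8$ multiply-adds. The following elementwise product and sum add only $N \cdot B \cdot C \approx 2.8 \times 10^6$. This is consistent with the matrix multiplication dominating the runtime.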

import numpy as np


def compute_distance(embedding: np.ndarray, mean: np.ndarray, inv_covariance: np.ndarray) -> np.ndarray:
    batch, channel, height, width = embedding.shape
    embedding = embedding.reshape(batch, channel, height * width)

    # calculate mahalanobis distances
    # delta: (height * width, batch, channel)
    delta = np.ascontiguousarray((embedding - mean).transpose(2, 0, 1))

    # quadratic form delta^T @ inv_cov @ delta for every (pixel, batch) pair
    distances = ((delta @ inv_covariance) * delta).sum(2).transpose(1, 0)
    distances = distances.reshape(batch, 1, height, width)
    # guard against small negative values (e.g. from rounding) before the square root
    distances = np.sqrt(distances.clip(0))

    return distances
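
The same quantity can be written as a single `np.einsum` call. This is useful as a compact reference for checking faster versions. The sketch below uses a hypothetical function name and assumes `mean` has shape (channel, height * width), as the reshape in the code implies. It skips the explicit transpose, but it has not been benchmarked here, so no speed-up is claimed.

```python
import numpy as np


def compute_distance_einsum(embedding, mean, inv_covariance):
    """Same result as compute_distance, written with np.einsum (hypothetical helper).

    embedding: (B, C, H, W), mean: (C, H*W), inv_covariance: (H*W, C, C)
    """
    batch, channel, height, width = embedding.shape
    delta = embedding.reshape(batch, channel, height * width) - mean  # (B, C, N)
    # dist[b, p] = sum over c, d of delta[b, c, p] * inv_covariance[p, c, d] * delta[b, d, p]
    dist = np.einsum("bcp,pcd,bdp->bp", delta, inv_covariance, delta, optimize=True)
    return np.sqrt(dist.clip(0)).reshape(batch, 1, height, width)
```

Because it computes the same values, `np.allclose` against the original function is a quick sanity check for any optimized rewrite.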

I've tried to convert the code to Numba with @njit. I preallocated the intermediate matrix and perform the multiplication as a loop of smaller 2D matrix products, since Numba's matmul (@) does not support 3-dimensional arrays.

import numba as nb
import numpy as np


@nb.njit
def compute_distance(embedding: np.ndarray, mean: np.ndarray, inv_covariance: np.ndarray) -> np.ndarray:
    batch, channel, height, width = embedding.shape
    embedding = embedding.reshape(batch, channel, height * width)

    # calculate mahalanobis distances
    delta = np.ascontiguousarray((embedding - mean).transpose(2, 0, 1))
    inv_covariance = np.ascontiguousarray(inv_covariance)
    
    intermediate_matrix = np.zeros_like(delta)
    # one 2D product per pixel, because Numba's @ only handles 2D arrays
    for i in range(intermediate_matrix.shape[0]):
        intermediate_matrix[i] = delta[i] @ inv_covariance[i]

    distances = (intermediate_matrix * delta).sum(2).transpose(1, 0)
    distances = np.ascontiguousarray(distances)
    distances = distances.reshape(batch, 1, height, width)
    distances = np.sqrt(distances.clip(0))

    return distances

I added a few ascontiguousarray calls. The last one is required, otherwise the code doesn't work. The others only suppress the warning that @ is faster on contiguous arrays (the gain doesn't seem to be large).
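
The last `ascontiguousarray` matters because `transpose` returns a strided view rather than a copy, and Numba's `reshape` only supports contiguous arrays. The small NumPy-only illustration below uses arbitrary array names and shows what the call changes.

```python
import numpy as np

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
t = x.transpose(2, 0, 1)          # view: same buffer, permuted strides
c = np.ascontiguousarray(t)       # copy laid out in C order

print(t.flags["C_CONTIGUOUS"])    # False
print(c.flags["C_CONTIGUOUS"])    # True
print(np.shares_memory(x, t), np.shares_memory(x, c))  # True False
print(np.array_equal(t, c))       # True: same values, different memory layout
```

Plain NumPy silently copies when reshaping a non-contiguous array, which is why the original version works without the explicit call.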

Is there a way to make the code faster, either by improving this or rethinking it in a different mathematical way?

Edit - Final implementation

Based on Jérôme Richard's answer, I ended up with this code:

import numba as nb
import numpy as np


@nb.njit()
def matmul(delta: np.ndarray, inv_covariance: np.ndarray):
    """Computes ((delta @ inv_covariance) * delta).sum(1) using numba, i.e. one
    Mahalanobis quadratic form per row of delta.

    Args:
        delta: Matrix of dimension BxD
        inv_covariance: Matrix of dimension DxD

    Returns:
        Vector of dimension B
    """
    si, sj, sk = delta.shape[0], inv_covariance.shape[1], delta.shape[1]
    assert sk == inv_covariance.shape[0]
    line = np.zeros(sj, dtype=delta.dtype)
    res = np.zeros(si, dtype=delta.dtype)
    for i in range(si):
        line.fill(0.0)
        for k in range(sk):
            factor = delta[i, k]
            for j in range(sj):
                line[j] += factor * inv_covariance[k, j]
        for j in range(sj):
            res[i] += line[j] * delta[i, j]
    return res


@nb.njit
def mean_subtraction(embeddings: np.ndarray, mean: np.ndarray):
    """Computes embeddings - mean using numba. This is required because I had errors with the plain
    NumPy broadcasting expression.

    Args:
        embeddings: Embedding matrix of dimension FxBxD
        mean: Mean matrix of dimension BxD

    Returns:
        Delta matrix of dimension FxBxD
    """
    output_matrix = np.zeros_like(embeddings)
    for i in range(embeddings.shape[0]):
        output_matrix[i] = embeddings[i] - mean

    return output_matrix


@nb.njit(parallel=True)
def compute_distance_numba(embedding: np.ndarray, mean: np.ndarray, inv_covariance: np.ndarray) -> np.ndarray:
    """Compute distance score using numba.

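    Args:
        embedding: Embedding tensor of dimension batch x channel x height x width
        mean: Mean of the multivariate Gaussian distribution, dimension channel x (height * width)
        inv_covariance: Inverse covariance matrices of the multivariate Gaussian distribution,
            dimension (height * width) x channel x channel

    Returns:
        Distances of dimension batch x 1 x height x width
    """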
    batch, channel, height, width = embedding.shape
    embedding = embedding.reshape(batch, channel, height * width)

    delta = np.ascontiguousarray(mean_subtraction(embedding, mean).transpose(2, 0, 1))
    inv_covariance = np.ascontiguousarray(inv_covariance)

    intermediate_matrix = np.zeros((delta.shape[0], delta.shape[1]), dtype=delta.dtype)
    for i in nb.prange(intermediate_matrix.shape[0]):
        intermediate_matrix[i] = matmul(delta[i], inv_covariance[i])

    distances = intermediate_matrix.transpose(1, 0)
    distances = np.ascontiguousarray(distances)
    distances = distances.reshape(batch, 1, height, width)
    distances = np.sqrt(distances.clip(0))

    return distances

The differences from the accepted answer are the custom subtraction function and the explicit dtype for the intermediate matrix, which avoids the default np.float64.
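
The explicit dtype matters because `np.zeros` defaults to float64. With float32 inputs, the intermediate buffer and everything derived from it would then be float64 at twice the memory. The question does not state the data type, so the float32 input below is an assumption for illustration.

```python
import numpy as np

rows = np.ones((4, 3), dtype=np.float32)
default = np.zeros(4)                     # no dtype given: float64
explicit = np.zeros(4, dtype=rows.dtype)  # follows the input: float32
default[:] = rows.sum(1)
explicit[:] = rows.sum(1)

print(default.dtype, explicit.dtype)      # float64 float32
print(default.nbytes, explicit.nbytes)    # 32 16
print(np.sqrt(default).dtype)             # float64: the upcast propagates downstream
```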


Solution

  • First of all, matrix multiplications are done by libraries called BLAS, and most implementations are efficient parallel ones. That being said, for a batch of small matrices, the parallel implementation cannot be very efficient: the granularity is too small, so the overhead of using multiple threads becomes significant. It is better to parallelize the outer loop and use a sequential matrix-multiplication code.

    Since the matrices involved in the matrix multiplication are pretty small, it is better to reimplement the matrix multiplication manually. This removes the overhead of calling the functions of the matrix-multiplication library (BLAS) and also ensures that no threads are used during the matrix multiplication. One needs to take care to read and write values contiguously, though, so that the operation is SIMD-friendly.

    On top of that, the matrix multiplication can be merged with the next line, (intermediate_matrix * delta).sum(2), so as to write a smaller output array and avoid reading big temporary arrays back. This is critical since RAM is slow. This strategy also reduces the memory footprint while being faster and scaling better. It is certainly a good idea to also merge the operation with the line (embedding - mean).transpose(2, 0, 1), though I did not test it.


    Implementation

    Here is an implementation taking all these points into account except the very last one:

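    import numba as nb
    import numpy as np


    @nb.njit()
    def matmul(delta, inv_covariance):
        # Computes ((delta @ inv_covariance) * delta).sum(1) row by row,
        # keeping a single row buffer instead of a full temporary matrix.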
        si, sj, sk = delta.shape[0], inv_covariance.shape[1], delta.shape[1]
        assert sk == inv_covariance.shape[0]
        line = np.zeros(sj, dtype=delta.dtype)
        res = np.zeros(si, dtype=delta.dtype)
        for i in range(si):
            line.fill(0.0)
            for k in range(sk):
                factor = delta[i, k]
                for j in range(sj):
                    line[j] += factor * inv_covariance[k, j]
            for j in range(sj):
                res[i] += line[j] * delta[i, j]
        return res
    
    @nb.njit(parallel=True)
    def compute_distance(embedding: np.ndarray, mean: np.ndarray, inv_covariance: np.ndarray) -> np.ndarray:
        batch, channel, height, width = embedding.shape
        embedding = embedding.reshape(batch, channel, height * width)
    
        # calculate mahalanobis distances
        delta = np.ascontiguousarray((embedding - mean).transpose(2, 0, 1))
        inv_covariance = np.ascontiguousarray(inv_covariance)
        
        intermediate_matrix = np.zeros((delta.shape[0], delta.shape[1]))
        for i in nb.prange(intermediate_matrix.shape[0]):
            intermediate_matrix[i] = matmul(delta[i], inv_covariance[i])
    
        distances = intermediate_matrix.transpose(1, 0)
        distances = np.ascontiguousarray(distances)
        distances = distances.reshape(batch, 1, height, width)
        distances = np.sqrt(distances.clip(0))
    
        return distances
    

    Results

    This is about 3 times faster on my i5-9600KF CPU (6 cores). Most of the time now appears to be spent in the first line computing delta, which can also be merged into the kernel for better performance (assuming the array strides are reasonable). Note that the compilation time is not included in the timings, and the results are equal (checked with np.allclose).
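
The answer leaves the mean subtraction unmerged and untested. Below is a minimal sketch of what fusing it could look like. It assumes the shapes from the question: embedding (B, C, H, W), mean (C, H*W) and inv_covariance (H*W, C, C). `compute_distance_fused` is a hypothetical name, and the fallback decorator is an assumption added only so the sketch runs (slowly) without Numba installed. The sketch keeps the answer's i-k-j loop order so the innermost loop reads `inv_covariance[p, k, :]` contiguously. Its performance has not been measured.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:  # assumption: plain-Python fallback, only for trying the sketch
    prange = range

    def njit(*args, **kwargs):
        # supports both @njit and @njit(parallel=True)
        if len(args) == 1 and callable(args[0]):
            return args[0]
        return lambda func: func


@njit(parallel=True)
def compute_distance_fused(embedding, mean, inv_covariance):
    """Mean subtraction, quadratic form and sqrt in one pass (hypothetical helper)."""
    batch, channel, height, width = embedding.shape
    n_pix = height * width
    emb = np.ascontiguousarray(embedding).reshape(batch, channel, n_pix)
    out = np.empty((batch, n_pix), dtype=embedding.dtype)
    for p in prange(n_pix):  # parallel over pixels, sequential inside
        delta = np.empty(channel, dtype=embedding.dtype)
        line = np.empty(channel, dtype=embedding.dtype)
        for b in range(batch):
            # mean subtraction on the fly: no (H*W, B, C) temporary array
            for c in range(channel):
                delta[c] = emb[b, c, p] - mean[c, p]
            # line = delta @ inv_covariance[p]
            line[:] = 0.0
            for k in range(channel):
                factor = delta[k]
                for j in range(channel):
                    line[j] += factor * inv_covariance[p, k, j]
            # acc = delta^T inv_covariance[p] delta
            acc = 0.0
            for j in range(channel):
                acc += line[j] * delta[j]
            out[b, p] = np.sqrt(max(acc, 0.0))
    return out.reshape(batch, 1, height, width)
```

With Numba, the first call includes compilation, so call it once before timing, as the answer does. Reading `emb[b, c, p]` strides over the channel axis, which is the "reasonable strides" caveat mentioned in the answer.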