Tags: python, algorithm, optimization

How to remove redundancy when computing sums for many rings


I have this code to compute the sum of the values in a matrix that are closer than some distance but further away than another. Here is the code with some example data:

square = [[ 3,  0,  1,  3, -1,  1,  1,  3, -2, -1],
   [ 3, -1, -1,  1,  0, -1,  2,  1, -2,  0],
   [ 2,  2, -2,  0,  1, -3,  0, -2,  2,  1],
   [ 0, -3, -3, -1, -1,  3, -2,  0,  0,  3],
   [ 2,  2,  3,  2, -1,  0,  3,  0, -3, -1],
   [ 1, -1,  3,  1, -3,  3, -2,  0, -3,  0],
   [ 2, -2, -2, -3, -2,  1, -2,  0,  0,  3],
   [ 0,  3,  0,  1,  3, -1,  2, -3,  0, -2],
   [ 0, -2,  2,  2,  2, -2,  0,  2,  1,  3],
   [-2, -2,  0, -2, -2,  2,  0,  2,  3,  3]]

def enumerate_matrix(matrix):
    """
    Enumerate the elements in the matrix.
    """
    for x, row in enumerate(matrix):
        for y, value in enumerate(row):
            yield x, y, value

def sum_of_values(matrix, d):
    """
    Calculate the sum of values based on specified conditions.
    """
    total_sum = 0
    for x, y, v in enumerate_matrix(matrix):
        U = x * x + x + y * y + y + 1
        if d * d * 2 < U < (d + 1) ** 2 * 2:
            total_sum += v
    return total_sum

For this case, I want to compute sum_of_values(square, x) for x in [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5]. This is fast enough but I also want to do it for much larger matrices and the code is then doing a lot of redundant computation. How can I remove this redundancy?

For example:

import numpy as np
square = np.random.randint(-3, 4, size=(1000, 1000))
for i in range(1000):
    result = sum_of_values(square, i + 0.5)
    print(f"Sum of values: {result} {i}")

This is too slow as I will need to perform this calculation for thousands of different matrices. How can the redundant calculations in my code be removed?

The key problem, I think, is that enumerate_matrix should only look at cells that are likely to be at the right distance, instead of rechecking every cell in the matrix for each value of d.


Timings

For a 400 by 400 matrix my code takes approx 26 seconds.

from tqdm import tqdm

def calc_values(matrix, n):
    scores = []
    for i in tqdm(range(n)):
        # use the matrix argument, not the global square
        result = sum_of_values(matrix, i + 0.5)
        scores.append(result)
    return scores

n = 400
square = np.random.randint(-3, 4, size=(n, n))
%timeit calc_values(square, n)

For comparison, on the same 400 by 400 matrix:

  • RomanPerekhrest's code takes approx. 119 ms, including the time to build the U array.
  • Reinderien's code takes approx. 149 ms.
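
%timeit is an IPython magic, so the measurement above only runs in IPython or Jupyter. A rough plain-Python equivalent can be sketched with the standard timeit module. The matrix size of 40 and the fixed seed are assumptions chosen so the sketch finishes quickly, and tqdm is left out so that nothing beyond NumPy is needed:

```python
import timeit

import numpy as np


def sum_of_values(matrix, d):
    """Reference ring sum, same logic as the code in the question."""
    total = 0
    for x, row in enumerate(matrix):
        for y, v in enumerate(row):
            U = x * x + x + y * y + y + 1
            if d * d * 2 < U < (d + 1) ** 2 * 2:
                total += v
    return total


def calc_values(matrix, n):
    """Ring sums for d = 0.5, 1.5, ..., n - 0.5."""
    return [sum_of_values(matrix, i + 0.5) for i in range(n)]


n = 40  # assumption: much smaller than 400 so the baseline runs in well under a second
rng = np.random.default_rng(0)
square = rng.integers(-3, 4, size=(n, n))

# Take the best of three single runs, similar in spirit to what %timeit reports.
best = min(timeit.repeat(lambda: calc_values(square, n), number=1, repeat=3))
print(f"calc_values on a {n} x {n} matrix: {best:.3f} s")
```

The baseline rescans all n * n cells for each of the n rings, so its run time grows roughly with the cube of n; timings at n = 40 are not directly comparable to the 400 by 400 figures above.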

Solution

  • Traverse the input square matrix just once to build an array of (U, value) pairs, where U is the parameter computed from each cell's row and column indices.
    Then, for each d, apply a vectorized mask to sum the values whose U satisfies the condition.

    def make_U_array(mtx):
        """Make an array of (U, value) pairs"""
        arr = np.array([(x * x + x + y * y + y + 1, value)
                        for x, row in enumerate(mtx)
                        for y, value in enumerate(row)])
        return arr
    
    def sum_values_by_cond(U_values, d):
        # mask values where U parameters fit the condition
        m = (d * d * 2 < U_values[:, 0]) & (U_values[:, 0] < (d + 1) ** 2 * 2)
        return np.sum(U_values[:, 1][m])
    

    Update: an alternative, faster version of make_U_array based on np.indices (to get the row/column indices). It should be about 3x faster than the list-comprehension version above:

    def make_U_array(mtx):
        """Make an array of (U, value) pairs"""
        x, y = np.indices(mtx.shape)
        x, y = x.flatten(), y.flatten() # row/column indices
        arr = np.column_stack((x * x + x + y * y + y + 1, np.ravel(mtx)))
        return arr
    

    Sample case (assuming your initial square array):

    U_arr = make_U_array(square)
    
    for i in range(10):
        result = sum_values_by_cond(U_arr, i + 0.5)
        print(f"Sum of values: {result} {i}")
    

    Sum of values: 6 0
    Sum of values: 3 1
    Sum of values: -1 2
    Sum of values: 4 3
    Sum of values: 3 4
    Sum of values: 3 5
    Sum of values: -11 6
    Sum of values: 4 7
    Sum of values: 7 8
    Sum of values: 3 9
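
A possible further step, not part of the answer above: the masked version still scans the whole U array once per value of d. If d always has the form i + 0.5, as in the question, then because U is an integer the condition 2d² < U < 2(d + 1)² is equivalent to 2i² + 2i + 1 <= U <= 2i² + 6i + 4. The upper bound of ring i is one less than the lower bound of ring i + 1, so every cell belongs to exactly one ring and all ring sums can be accumulated in a single pass. The following is a minimal sketch under that assumption; ring_sums is a hypothetical helper name, not something from the original code:

```python
import numpy as np


def ring_sums(mtx):
    """Sum the values of every ring d = i + 0.5 in one pass (hypothetical helper)."""
    mtx = np.asarray(mtx)
    x, y = np.indices(mtx.shape)
    U = (x * x + x + y * y + y + 1).ravel()

    # Assumption: d = i + 0.5 for integer i >= 0. For integer U, ring i then
    # starts at U = 2*i*i + 2*i + 1 and ends just before ring i + 1 starts.
    n_rings = max(mtx.shape)  # the farthest corner cell lands in ring n_rings - 1 at most
    i = np.arange(n_rings + 1)
    lower = 2 * i * i + 2 * i + 1

    # Index of the ring each cell belongs to: the last lower bound that is <= U.
    ring = np.searchsorted(lower, U, side="right") - 1

    # Accumulate the cell values per ring; bincount returns float64 when weights are given.
    return np.bincount(ring, weights=mtx.ravel(), minlength=n_rings)
```

With this sketch, ring_sums(square)[i] should agree with sum_values_by_cond(U_arr, i + 0.5) for each i, returned as floats. It does not apply to arbitrary values of d, since rings for other d can overlap or leave gaps.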