Tags: python, arrays, numpy, similarity

Perform group operation on 2D numpy array


I have a 2D NumPy array (in fact, a similarity matrix) on which I need to compute block-wise averages. For instance, with the following matrix:

import numpy as np

sima = np.array([[1,0.8,0.7,0.3,0.1,0.5],
                 [0.8,1,0.1,0.5,0.2,0.5],
                 [0.7,0.1,1,0.1,0.3,0.9],
                 [0.3,0.5,0.1,1,0.8,0.5],
                 [0.1,0.2,0.3,0.8,1,0.5],
                 [0.5,0.5,0.9,0.5,0.5,1]])

and the following labels vector:

labels = np.array([1,1,1,2,2,3])

This means that the first three rows of the matrix (as well as the first three columns, since a similarity matrix is symmetric) correspond to cluster 1, the next two correspond to cluster 2, and the last one corresponds to cluster 3.

I need to compute the average of each block of sima defined by the clusters in labels, yielding the following output (rounded to two decimals):

0.69 0.25 0.63 
0.25 0.90 0.50 
0.63 0.50 1.00
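To make the target quantity explicit: entry (a, b) of the output is the mean of the submatrix whose rows belong to cluster a and whose columns belong to cluster b. The sketch below computes exactly that by extracting each block with np.ix_. The function name block_means_reference is hypothetical, and since it still loops over every pair of clusters it is meant as a correctness check, not as the fast solution.

```python
import numpy as np

def block_means_reference(sima, labels):
    """Hypothetical helper: mean of each (cluster a, cluster b) block of sima."""
    labels = np.asarray(labels)
    clusters = np.unique(labels)  # sorted unique labels, e.g. [1, 2, 3]
    out = np.empty((len(clusters), len(clusters)))
    for i, a in enumerate(clusters):
        rows = np.flatnonzero(labels == a)  # indices of cluster a's members
        for j, b in enumerate(clusters):
            cols = np.flatnonzero(labels == b)  # indices of cluster b's members
            # np.ix_ selects only the rows x cols block instead of masking the full matrix
            out[i, j] = sima[np.ix_(rows, cols)].mean()
    return out
```

For example, block (1, 1) averages the nine entries of the top-left 3 x 3 block: 6.2 / 9 ≈ 0.69.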

So far, I have a working solution using a double loop over the labels and masked arrays:

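import numpy as np
import pandas as pd

# labels_matrix[r, c] == labels[c]; its transpose gives labels[r]
labels_matrix = np.tile(np.array(labels), (len(labels), 1))
output = pd.DataFrame(np.zeros(shape=(3, 3)))

for i in range(3):
  for j in range(3):
    # Mask every entry whose row is not in cluster i+1 or whose column is not in cluster j+1
    mask = (labels_matrix != j+1) | (labels_matrix.T != i+1)
    output.loc[i, j] = np.ma.array(sima, mask=mask).mean()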

This code yields the correct output, but my actual matrix is 50k x 50k, and the code is far too slow: every (i, j) iteration builds a mask over, and averages, the entire matrix. How can I make it faster?

Note: I need a speedup of at least an order of magnitude, so I expect tricks like exploiting the symmetry of the similarity matrix won't be enough.


Solution

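  • For sorted labels, we can use np.add.reduceat: find the index where each run of equal labels starts, sum the matrix over those row slices and then over the column slices, and divide each block sum by its number of entries (the product of the two cluster sizes):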

    In [62]: idx = np.flatnonzero(np.r_[True,labels[:-1] != labels[1:],True])
    
    In [63]: c = np.diff(idx)
    
    In [64]: sums = np.add.reduceat(np.add.reduceat(sima,idx[:-1],axis=0),idx[:-1],axis=1)
    
    In [65]: sums/(c[:,None]*c)
    Out[65]: 
    array([[0.68888889, 0.25      , 0.63333333],
           [0.25      , 0.9       , 0.5       ],
           [0.63333333, 0.5       , 1.        ]])
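The three steps above can be collected into a reusable function. This is a sketch, not part of the original answer: the name block_means is hypothetical, and the reordering branch for unsorted labels is an added assumption. Because reduceat sums contiguous slices, each cluster's members must be adjacent; when the labels are not sorted, the sketch applies a stable argsort and permutes rows and columns to match. That permutation copies the whole matrix, which for a 50k x 50k float64 array means roughly another 20 GB.

```python
import numpy as np

def block_means(sima, labels):
    """Hypothetical helper: block-wise means of a square matrix via np.add.reduceat."""
    labels = np.asarray(labels)
    if np.any(labels[1:] < labels[:-1]):
        # Assumption: sort so equal labels become contiguous.
        # Fancy indexing with np.ix_ creates a full permuted copy of sima.
        order = np.argsort(labels, kind="stable")
        labels = labels[order]
        sima = sima[np.ix_(order, order)]
    # Start index of each run of equal labels, plus len(labels) as the final end
    idx = np.flatnonzero(np.r_[True, labels[:-1] != labels[1:], True])
    c = np.diff(idx)  # cluster sizes
    # Sum rows within each cluster, then columns within each cluster
    sums = np.add.reduceat(np.add.reduceat(sima, idx[:-1], axis=0), idx[:-1], axis=1)
    # Each block has c[a] * c[b] entries
    return sums / (c[:, None] * c)
```

On the example data, block_means(sima, labels) reproduces the 3 x 3 result above, and applying the same permutation to the rows, columns, and labels leaves the result unchanged.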