How to use integer values of a matrix as index for another matrix using numpy or Theano?


I have four matrices of the same shape: (1) a matrix I of integer values, (2) a matrix J of integer values, (3) a matrix D of float values, and (4) a matrix V of float values.

I want to use these 4 matrices to construct an "output" matrix in the following way:

  1. To find the value of element (i, j) of the output matrix, find all cells of matrix I that are equal to i and all cells of matrix J that are equal to j.
  2. Keep only the cells that satisfy both conditions (remember that I and J have the same shape).
  3. Among the selected cells, find the one with the smallest value of D.
  4. Take that cell and look up its value in matrix V.

This gives the value of element (i, j) of the output matrix. I do this for every i and j.

I would like to solve this problem using either numpy or Theano.

Of course I could loop over all values of i and j, but I think (hope) there is a more efficient way.
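
For reference, the loop-based version of the four steps could be sketched as follows. This is only a minimal sketch: it assumes the indices in I and J are non-negative and start at 0, so that the output has `I.max() + 1` rows and `J.max() + 1` columns, and the function name `build_output_loop` is made up for illustration.

```python
import numpy as np

def build_output_loop(I, J, D, V):
    """For each (i, j), return V at the cell with the smallest D among cells where I == i and J == j."""
    # Assumption: indices start at 0, so the output size follows from the largest index
    n_rows, n_cols = int(I.max()) + 1, int(J.max()) + 1
    out = np.full((n_rows, n_cols), np.nan)  # nan marks (i, j) pairs with no matching cell
    for i in range(n_rows):
        for j in range(n_cols):
            selected = (I == i) & (J == j)  # steps 1 and 2: cells satisfying both conditions
            if selected.any():
                # Step 3: hide unselected cells behind +inf so argmin ignores them
                flat_idx = np.where(selected, D, np.inf).argmin()
                # Step 4: read the value of that cell from V
                out[i, j] = V.flat[flat_idx]
    return out
```

Each iteration scans the whole input, so the cost grows with the number of (i, j) pairs times the number of cells, which is what a vectorized version would try to avoid paying in Python-level loops.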

ADDED

As requested, I provide an example:

Here is the matrix I:

 0   1   2
 1   1   0
 0   0   2

Here is the matrix J:

 1   1   1
 1   2   1
 0   1   0

Here is the matrix D:

 1.2   3.4   2.2
 2.2   4.3   2.3
 7.1   6.1   2.7

And finally we have the matrix V:

 1.1   8.1   9.1
 3.1   7.1   2.1
 0.1   5.1   3.1

As you can see, all 4 matrices have the same shape (3 x 3), but they could have another shape (for example 2 x 5). The main thing is that the shape is the same for all 4 matrices.

The values of matrix I range from 0 to 2, so the output matrix should have 3 rows. Likewise, the output matrix should have 3 columns, because the values of matrix J also range from 0 to 2.
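
In other words, the output shape follows from the index ranges, not from the shape of the inputs. A minimal sketch of that rule, assuming the indices are non-negative and start at 0 (the names `n_rows`, `n_cols` and `out_shape` are illustrative only):

```python
import numpy as np

I = np.array([[0, 1, 2], [1, 1, 0], [0, 0, 2]])
J = np.array([[1, 1, 1], [1, 2, 1], [0, 1, 0]])

# Assumption: indices start at 0, so the largest index plus one gives the size
n_rows = int(I.max()) + 1
n_cols = int(J.max()) + 1
out_shape = (n_rows, n_cols)
```

For a 2 x 5 input whose indices still run from 0 to 2, the same rule would again give a 3 x 3 output.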

Let us first find the element (0, 1) of the output matrix. In the I matrix the following cells (marked by x) contain 0.

 x   .   .
 .   .   x
 x   x   .

In the matrix J the following elements contain 1:

 x   x   x
 x   .   x
 .   x   .

An intersection of these two sets of cells is:

 x   .   .
 .   .   x
 .   x   .

The corresponding distances (the values of D in those cells) are:

 1.2    .     . 
  .     .    2.3
  .    6.1    . 

So, the smallest distance is located in the top-left corner. As a result we take the value from the top-left corner of the matrix V (this value is 1.1).
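
The same steps for element (0, 1) can be reproduced with boolean masks. This is a sketch for checking the example by hand, not a full solution; the variable names are chosen here for illustration.

```python
import numpy as np

I = np.array([[0, 1, 2], [1, 1, 0], [0, 0, 2]])
J = np.array([[1, 1, 1], [1, 2, 1], [0, 1, 0]])
D = np.array([[1.2, 3.4, 2.2], [2.2, 4.3, 2.3], [7.1, 6.1, 2.7]])
V = np.array([[1.1, 8.1, 9.1], [3.1, 7.1, 2.1], [0.1, 5.1, 3.1]])

i, j = 0, 1
# Intersection of the cells where I == 0 and the cells where J == 1
selected = (I == i) & (J == j)
# Distances of the selected cells, +inf everywhere else
candidates = np.where(selected, D, np.inf)
# Position of the smallest distance, converted from a flat index to (row, col)
row, col = np.unravel_index(candidates.argmin(), D.shape)
value = V[row, col]
print(row, col, value)
```

Here `selected` has three True entries, the smallest candidate distance is 1.2 at position (0, 0), and `value` is 1.1, matching the walkthrough.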

This is how we found the value of the (0, 1) element of the output matrix. We repeat the same procedure for all possible combinations of indices (in total we have 3 x 3 = 9 combinations). For some combinations no cell satisfies both conditions; in that case we set the value to nan.


Solution

  • Here's a vectorized approach using broadcasting -

    import numpy as np
    
    # Output shape comes from the index ranges, not from the input shape
    m,n = I.max()+1, J.max()+1
    
    # Get masks of matching elements against the iterators
    Imask = I == np.arange(m)[:,None,None,None]
    Jmask = J == np.arange(n)[:,None,None]
    
    # Get the mask of intersecting ones, shape (m, n) + I.shape
    mask = Imask & Jmask
    
    # Get D intersection masked array (non-matching cells become inf)
    Dvals = np.where(mask,D,np.inf)
    
    # Get argmin along merged last two axes. Index into flattened V for final o/p
    out = V.ravel()[Dvals.reshape(m,n,-1).argmin(-1)]
    

    Sample input, output -

    In [136]: I = np.array([[0,1,2],[1,1,0],[0,0,2]])
         ...: J = np.array([[1,1,1],[1,2,1],[0,1,0]])
         ...: D = np.array([[1.2, 3.4, 2.2],[2.2, 4.3, 2.3],[7.1, 6.1, 2.7]])
         ...: V = np.array([[1.1 , 8.1, 9.1],[3.1, 7.1, 2.1],[0.1, 5.1, 3.1]])
         ...: 
    
    In [144]: out
    Out[144]: 
    array([[ 0.1,  1.1,  1.1], # To verify: out[0,1] = 1.1
           [ 1.1,  3.1,  7.1],
           [ 3.1,  9.1,  1.1]])
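
Note that the output above shows 1.1 at positions (0, 2), (1, 0) and (2, 2), where no cell satisfies both conditions: `argmin` over a slice that is entirely `inf` returns index 0, so those entries pick up `V[0, 0]`. The question asks for nan there. One possible variant is sketched below, assuming the indices in I and J start at 0; the function name `build_output` is hypothetical.

```python
import numpy as np

def build_output(I, J, D, V):
    """Broadcasting version that writes nan where no cell has I == i and J == j."""
    # Assumption: indices start at 0, so the output size follows from the largest index
    m, n = int(I.max()) + 1, int(J.max()) + 1

    # mask[i, j, r, c] is True where I[r, c] == i and J[r, c] == j
    Imask = I == np.arange(m)[:, None, None, None]
    Jmask = J == np.arange(n)[:, None, None]
    mask = Imask & Jmask

    # Keep D on matching cells, +inf elsewhere, then argmin over the flattened input axes
    Dvals = np.where(mask, D, np.inf)
    flat_idx = Dvals.reshape(m, n, -1).argmin(-1)
    out = V.ravel()[flat_idx].astype(float)

    # (i, j) pairs without any matching cell fell back to index 0 above; overwrite with nan
    out[~mask.any(axis=(2, 3))] = np.nan
    return out
```

For the sample matrices this yields nan at (0, 2), (1, 0) and (2, 2) and leaves the other six entries unchanged. The intermediate mask has shape (m, n) plus the input shape, so memory use grows with the product of the output size and the input size.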