Index a batch of numpy vectors with a batch of numpy index matrices


If I have a vector vec, I can index it with a matrix as follows:

import numpy as np

vec = np.asarray([1,2,3,4]) # Shape (4,)

mat = np.asarray([[0,2],
                  [3,1]]) # Shape (2,2)

result = vec[mat] # Shape (2,2)

# result = array([[1, 3],
#                 [4, 2]])

Now suppose I have batches of vectors and matrices instead. What would be the simplest way to do the same kind of indexing for each element in the batch, strictly using numpy? For instance:

vecs = np.asarray([[1, 2, 3, 4],
                   [5, 6, 7, 8],
                   [9,10,11,12]]) # Shape (3,4)

mats = np.asarray([ [[0,1],
                     [1,0]],
                    [[0,2],
                     [1,1]],
                    [[3,1],
                     [0,0]] ]) # Shape (3,2,2)

# results = np.asarray([ [[1,2],
#                         [2,1]],
#                        [[5,7],
#                         [6,6]],
#                        [[12,10],
#                         [9,9]] ]) # Shape (3,2,2)

What about higher-dimensional batches? Is there a simple and general way to vectorize such indexing to arbitrary batches?


Solution

  • Add a (3,1,1) array to index the first dimension:

    In [196]: vecs[np.arange(3)[:,None,None],mats]
    Out[196]: 
    array([[[ 1,  2],
            [ 2,  1]],
    
           [[ 5,  7],
            [ 6,  6]],
    
           [[12, 10],
            [ 9,  9]]])
    

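    The (3,1,1) index array broadcasts against your (3,2,2) mats, so each entry mats[i,j,k] is paired with batch index i and the result holds vecs[i, mats[i,j,k]], with shape (3,2,2).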
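
For the higher-dimensional case asked about in the question, one possible generalization is sketched below. It is not part of the answer above: the helper name batched_index is hypothetical, and it assumes that mats has the same leading (batch) dimensions as vecs.shape[:-1]. It flattens the trailing index dimensions, gathers along the last axis with np.take_along_axis, and restores the shape of mats.

```python
import numpy as np

def batched_index(vecs, mats):
    """Index each vector in a batch with its own index array.

    Hypothetical helper: assumes vecs has shape batch_shape + (n,)
    and mats has shape batch_shape + index_shape, holding integer
    indices into the last axis of vecs.
    """
    vecs = np.asarray(vecs)
    mats = np.asarray(mats)
    batch_shape = vecs.shape[:-1]
    if mats.shape[:len(batch_shape)] != batch_shape:
        raise ValueError("mats must share the batch dimensions of vecs")
    # Collapse all index dimensions into one so that vecs and the
    # index array have the same number of dimensions.
    flat = mats.reshape(batch_shape + (-1,))
    # Gather along the last axis independently for every batch element.
    picked = np.take_along_axis(vecs, flat, axis=-1)
    # Restore the original index shape.
    return picked.reshape(mats.shape)

vecs = np.asarray([[1, 2, 3, 4],
                   [5, 6, 7, 8],
                   [9, 10, 11, 12]])   # Shape (3,4)

mats = np.asarray([[[0, 1], [1, 0]],
                   [[0, 2], [1, 1]],
                   [[3, 1], [0, 0]]])  # Shape (3,2,2)

results = batched_index(vecs, mats)    # Shape (3,2,2)
```

With the arrays from the question, this should give the same values as vecs[np.arange(3)[:,None,None], mats]. The same call should also work for batches with more than one leading dimension, for example vecs of shape (2,3,4) with mats of shape (2,3,2,2).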