Indexing a tensor with an index matrix in Theano?


I have a Theano tensor A with A.shape = (40, 20, 5) and a Theano integer matrix B with B.shape = (40, 20). Is there a one-line operation, in Theano syntax, that gives a matrix C with C.shape = (40, 20) and C[i, j] = A[i, j, B[i, j]]?

Essentially, I want to use B as an index matrix. What is the most efficient or elegant way to do this in Theano?
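
To pin down the intended semantics, here is a minimal NumPy sketch that computes C with explicit loops, using small toy shapes instead of (40, 20, 5). The helper name `gather_last_axis_loop` is hypothetical; it is a slow reference to check faster versions against, not a proposed solution.

```python
import numpy as np

def gather_last_axis_loop(A, B):
    """Hypothetical reference helper: C[i, j] = A[i, j, B[i, j]] via explicit loops."""
    n, m, _ = A.shape
    C = np.empty((n, m), dtype=A.dtype)
    for i in range(n):
        for j in range(m):
            # B[i, j] picks one position along A's last axis
            C[i, j] = A[i, j, B[i, j]]
    return C

# Toy shapes for illustration; the question uses (40, 20, 5) and (40, 20).
A = np.arange(4 * 2 * 5).reshape(4, 2, 5)
B = np.arange(4 * 2).reshape(4, 2) % 5   # every entry is a valid index 0..4
C = gather_last_axis_loop(A, B)
print(C)
```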


Solution

  • You can do this in NumPy with advanced (fancy) indexing:

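    import numpy as np
    
    A = np.arange(4 * 2 * 5).reshape(4, 2, 5)
    B = np.arange(4 * 2).reshape(4, 2) % 5   # valid indices into A's last axis (0..4)
    
    # Index arrays of shapes (4, 1), (2,) and (4, 2) broadcast together to (4, 2),
    # so C[i, j] = A[i, j, B[i, j]].
    C = A[np.arange(A.shape[0])[:, np.newaxis], np.arange(A.shape[1]), B]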
    

    Theano supports the same indexing, so the expression carries over almost unchanged:

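    import theano
    import theano.tensor as T
    
    AA = T.tensor3()   # tensor of shape (n, m, k), dtype floatX
    BB = T.imatrix()   # int32 index matrix of shape (n, m)
    
    # Same broadcasting trick: a column of row indices, a row of column
    # indices, and BB selecting the position along the last axis.
    CC = AA[T.arange(AA.shape[0]).reshape((-1, 1)), T.arange(AA.shape[1]), BB]
    
    f = theano.function([AA, BB], CC)
    
    # BB is declared int32, so cast B explicitly: np.arange usually yields
    # int64, which Theano refuses to downcast silently.
    f(A.astype(theano.config.floatX), B.astype('int32'))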
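
For comparison, here is a NumPy sketch of two other ways to express the same gather, checked against the advanced-indexing expression from the answer. It assumes NumPy 1.15 or later for `np.take_along_axis`; the flattened variant is a common reshape-and-pick trick, not something the answer itself uses.

```python
import numpy as np

A = np.arange(4 * 2 * 5).reshape(4, 2, 5)
B = np.arange(4 * 2).reshape(4, 2) % 5

# 1) Broadcast advanced indexing, as in the answer.
C_fancy = A[np.arange(A.shape[0])[:, np.newaxis], np.arange(A.shape[1]), B]

# 2) take_along_axis (NumPy >= 1.15): give B a trailing length-1 axis so it
#    has as many dimensions as A, gather along axis 2, then drop that axis.
C_take = np.take_along_axis(A, B[..., np.newaxis], axis=2)[..., 0]

# 3) Flatten the first two axes into rows, pick one element per row,
#    then restore the (n, m) shape. Row r = i * m + j holds A[i, j, :].
n, m, k = A.shape
rows = np.arange(n * m)
C_flat = A.reshape(n * m, k)[rows, B.ravel()].reshape(n, m)

assert np.array_equal(C_fancy, C_take)
assert np.array_equal(C_fancy, C_flat)
print(C_fancy)
```

Variant 3 indexes a 2-D array with two 1-D vectors only, a simpler pattern that may be easier to port to libraries with limited support for broadcast index arrays.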