Tags: python, tensorflow, slice, z-index, tensor

How to efficiently get values of each row in a tensor using indices?


I have a tensor called my_tensor with the shape [batch_size, seq_length] and another tensor named idx with the shape [batch_size, 1], which holds one index per row; the indices range from 0 to seq_length - 1.

I want to extract, from each row of my_tensor, the value at the index given by the corresponding entry of idx.

I tried to use tf.gather_nd and tf.gather but I was not successful.
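For reference, a minimal sketch of how tf.gather can select one column per row: with batch_dims=1, row i of the params tensor is indexed by row i of the indices tensor. It assumes TensorFlow 2.x and uses a small fixed tensor (illustrative values) instead of random data so the result is easy to check.

```python
import tensorflow as tf

# Small fixed example (illustrative values) instead of tf.random.uniform.
my_tensor = tf.constant([[0., 1., 2., 3., 4.],
                         [10., 11., 12., 13., 14.],
                         [20., 21., 22., 23., 24.]])
# idx with shape [batch_size, 1], as described in the question.
idx = tf.constant([[2], [0], [4]])

# batch_dims=1: row i of my_tensor is indexed by row i of idx.
picked = tf.gather(my_tensor, idx, batch_dims=1)  # shape [batch_size, 1]
values = tf.squeeze(picked, axis=1)               # shape [batch_size]

print(values.numpy())  # [ 2. 10. 24.]
```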

Consider the following example:

import tensorflow as tf

batch_size = 3
seq_length = 5
idx = tf.constant([2, 0, 4])  # one column index per row

my_tensor = tf.random.uniform(shape=(batch_size, seq_length))

I want to get the values at the following [row, column] positions

[[0, 2],
 [1, 0],
 [2, 4]]

from my_tensor.
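The [row, column] pairs can also be built from idx itself and passed to tf.gather_nd, which reads one element per pair. A minimal sketch, assuming TensorFlow 2.x and a flat idx of shape [batch_size]; the fixed my_tensor values are illustrative.

```python
import tensorflow as tf

my_tensor = tf.constant([[0., 1., 2., 3., 4.],
                         [10., 11., 12., 13., 14.],
                         [20., 21., 22., 23., 24.]])
idx = tf.constant([2, 0, 4])

# Row numbers 0..batch_size-1, same dtype as idx.
rows = tf.range(tf.shape(idx)[0], dtype=idx.dtype)
# Stack into one [row, column] pair per row.
coords = tf.stack([rows, idx], axis=1)

values = tf.gather_nd(my_tensor, coords)  # shape [batch_size]
print(coords.numpy().tolist())  # [[0, 2], [1, 0], [2, 4]]
print(values.numpy())           # [ 2. 10. 24.]
```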

I need to process these values further, so I would like to extract them all at once (if that is even possible) and efficiently; however, I could not come up with any other method.

I appreciate any help :)


Solution

  • The trick is to first convert your set of indices into a boolean mask, which you can then pass to tf.boolean_mask to pick out the selected element of each row of my_tensor.

    You can accomplish this by one-hot encoding the idx tensor.

    So, with idx = [2, 0, 4], tf.one_hot(idx, seq_length) converts it to:

    [ [0., 0., 1., 0., 0.],
      [1., 0., 0., 0., 0.],
      [0., 0., 0., 0., 1.] ]
    

    Then, putting it all together for, say, this my_tensor:

    [ [0.6413697 , 0.4079175 , 0.42499018, 0.3037368 , 0.8580252 ],
      [0.8698617 , 0.29096508, 0.11531639, 0.25421357, 0.5844104 ],
      [0.6442119 , 0.31816053, 0.6245482 , 0.7249261 , 0.7595779 ] ]
    
    

    we can proceed as follows:

    # boolean_mask expects a bool mask, so cast the float one-hot tensor
    result = tf.boolean_mask(my_tensor, tf.cast(tf.one_hot(idx, seq_length), tf.bool))
    

    to give:

    [0.42499018, 0.8698617 , 0.7595779 ]
    

    as expected.
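Putting the answer's steps together, here is a minimal runnable sketch of the one-hot/boolean_mask approach, assuming TensorFlow 2.x. It also accepts idx in the [batch_size, 1] shape from the question by reshaping it to [batch_size] first (an assumption about how idx arrives), and casts the one-hot mask to bool because tf.boolean_mask documents its mask as a boolean tensor. The fixed my_tensor is the one from the answer.

```python
import tensorflow as tf

seq_length = 5
my_tensor = tf.constant([[0.6413697, 0.4079175, 0.42499018, 0.3037368, 0.8580252],
                         [0.8698617, 0.29096508, 0.11531639, 0.25421357, 0.5844104],
                         [0.6442119, 0.31816053, 0.6245482, 0.7249261, 0.7595779]])
# idx as in the question: shape [batch_size, 1] (assumed input layout).
idx = tf.constant([[2], [0], [4]])

# Flatten to [batch_size] so one_hot yields a [batch_size, seq_length] mask.
flat_idx = tf.reshape(idx, [-1])
mask = tf.cast(tf.one_hot(flat_idx, seq_length), tf.bool)

# Each mask row has exactly one True, so one value per row, in row order.
result = tf.boolean_mask(my_tensor, mask)
print(result.numpy())
```

One caveat: tf.one_hot produces an all-off row for an index outside [0, seq_length), so boolean_mask would silently drop that row and result would be shorter than batch_size. tf.gather with batch_dims=1 reaches the same values without building the dense mask.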