tensorflow, slice, tensor

Change a TensorFlow tensor of shape (128, 10, 51) to a tensor of shape (128, 51) by selecting one element along the 2nd dimension


I want to change a TensorFlow tensor of shape (128, 10, 51) to a tensor of shape (128, 51) by selecting only one element along the 2nd dimension of the tensor. The indices to select are stored in an ndarray, like this:

A is a tensor of shape (128, 10, 51)

B is an ndarray of shape (128,) with elements from 0 to 9 (one index per row of A)

I did it with a for loop, but I want more compact code that does this in one or two lines.
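For reference, a loop version of this selection might look like the sketch below. The asker's actual loop is not shown, so this is an assumption; it uses random stand-in data with the stated shapes and stacks one (51,) slice per batch row.

```python
import numpy as np
import tensorflow as tf

# Random stand-ins with the shapes from the question (data is illustrative only).
A = tf.random.normal((128, 10, 51))
B = np.random.randint(0, 10, size=(128,))

# Loop baseline: for each batch row i, keep the slice A[i, B[i], :].
rows = [A[i, int(B[i])] for i in range(A.shape[0])]
loop_result = tf.stack(rows)  # shape (128, 51)
```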


Solution

  • You may try the following line:

    import tensorflow as tf
    result = tf.gather_nd(A, tf.stack([tf.range(tf.shape(A)[0]), tf.cast(B, tf.int32)], axis=1))
    

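    The tf.gather_nd function gathers slices from tensor A at the indices built by stacking tf.range(tf.shape(A)[0]) with B along axis 1. Row i of that (128, 2) index tensor is [i, B[i]], so row i of the result is A[i, B[i], :], and the result has shape (128, 51).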
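To make the index construction concrete, here is a small sketch on toy shapes (4, 3, 2), chosen for illustration rather than taken from the question. It casts B to tf.int32 because tf.range returns int32 while a NumPy integer array is usually int64, and tf.stack needs matching dtypes. It also shows tf.gather with batch_dims=1, an alternative that is not part of the original answer but should produce the same (batch, features) result.

```python
import numpy as np
import tensorflow as tf

# Toy tensor: batch=4, choices=3, features=2, so A[i, j, k] = 6*i + 2*j + k.
A = tf.reshape(tf.range(24, dtype=tf.float32), (4, 3, 2))
B = np.array([2, 0, 1, 2])  # one index in [0, 3) per batch row

# Cast B to int32 so it matches tf.range's dtype before stacking.
idx = tf.stack([tf.range(tf.shape(A)[0]), tf.cast(B, tf.int32)], axis=1)
print(idx.numpy())  # each row is [i, B[i]]

# gather_nd with (4, 2) indices into a rank-3 tensor returns shape (4, 2): A[i, B[i], :].
result = tf.gather_nd(A, idx)

# Alternative: gather along axis 1 while treating axis 0 as a batch dimension.
result_alt = tf.gather(A, tf.cast(B, tf.int32), axis=1, batch_dims=1)

print(result.numpy())  # rows: [4, 5], [6, 7], [14, 15], [22, 23]
```

With A of shape (128, 10, 51) and B of shape (128,), the same calls return a (128, 51) tensor.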