
TensorFlow - Index by axis


I want to index into the last axis of a tensor of arbitrary shape. The only fixed dimension is the last one, which always has size 2.

For example, if x has shape (1, 2, 2), I index the last axis with:

x_0 = x[:, :, 0]    # x_0, x_1 shapes are (1,2)
x_1 = x[:, :, 1]

Similarly, if x has shape (1, 2, 3, 4, 2), I index the last axis with:

x_0 = x[:, :, :, :, 0]   # x_0, x_1 shapes are (1,2,3,4)
x_1 = x[:, :, :, :, 1]

I haven't been able to find a TensorFlow function or indexing syntax for slicing a tensor of arbitrary rank this way.

I need a general way to index the last axis of a tensor of any shape, without writing out one ':' per leading dimension.


Solution

  • The slicing syntax in TensorFlow is very similar to NumPy's, so you can use the ellipsis (...) here:

    Ellipsis expands to the number of : objects needed for the selection tuple to index all dimensions. In most cases, this means that the length of the expanded selection tuple is x.ndim. There may only be a single ellipsis present.

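    In your case,

    x_0 = x[..., 0]
    x_1 = x[..., 1]

    select the two entries along the last axis of a tensor of any shape, because the ellipsis stands in for however many leading axes the tensor has.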

    You can also look at the answers to the question "What is the difference between the slice (:) and the ellipsis (…) operators in numpy?"
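A minimal TensorFlow 2.x sketch of the ellipsis approach, using the two shapes from the question. The helper name `split_last_axis` is hypothetical and introduced only for illustration, and the random inputs are placeholders. The sketch also checks that `x[..., 0]` selects the same elements as the explicit-colon form, and shows `tf.unstack(x, axis=-1)` as an alternative (not part of the original answer) that returns every slice of the last axis at once.

```python
import tensorflow as tf

def split_last_axis(x):
    """Hypothetical helper: return x[..., 0] and x[..., 1] for a tensor
    whose last axis has size 2."""
    # '...' expands to as many ':' as needed to cover every leading axis.
    return x[..., 0], x[..., 1]

# The two example shapes from the question (random values as placeholders).
x3 = tf.random.uniform((1, 2, 2))
x5 = tf.random.uniform((1, 2, 3, 4, 2))

a_0, a_1 = split_last_axis(x3)  # each has shape (1, 2)
b_0, b_1 = split_last_axis(x5)  # each has shape (1, 2, 3, 4)

# The ellipsis form selects exactly the same elements as the explicit colons.
same = tf.reduce_all(tf.equal(b_0, x5[:, :, :, :, 0]))
print(bool(same))

# Alternative (not from the original answer): tf.unstack along axis=-1
# returns a Python list with one tensor per entry of the last axis.
u_0, u_1 = tf.unstack(x5, axis=-1)
print(bool(tf.reduce_all(tf.equal(u_1, b_1))))
```

The ellipsis keeps the indexing expression independent of the tensor's rank, so the same helper works for both shapes. `tf.unstack` is convenient when you want all slices, but the unpacking `u_0, u_1 = ...` assumes the last axis has exactly size 2.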