
Get final values from a specific dimension/axis of an arbitrarily dimensioned PyTorch Tensor


Suppose I had a PyTorch tensor such as:

import torch

x = torch.randn([3, 4, 5])

and I wanted to get a new tensor, with the same number of dimensions, containing only the last entry along dimension 1. I could just do:

x[:, -1:, :]
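The `-1:` range slice is what keeps the number of dimensions unchanged: slicing with a range leaves dimension 1 in place with size 1, while a plain integer index removes it. A quick check with the shapes above:

```python
import torch

x = torch.randn([3, 4, 5])

# A range slice keeps dimension 1 (now of size 1) ...
kept = x[:, -1:, :]
# ... whereas an integer index drops dimension 1 entirely.
dropped = x[:, -1, :]

print(kept.shape)     # torch.Size([3, 1, 5])
print(dropped.shape)  # torch.Size([3, 5])
```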

However, if x had an arbitrary number of dimensions, and I wanted to get the final values from a specific dimension, what is the best way to do it?


Solution

  • You can use index_select:

    torch.index_select(x, dim=dim, index=torch.tensor([x.size(dim) - 1], device=x.device))
    

    The output tensor has the same number of dimensions as the input, with size 1 along dim. If you would rather drop that dimension, squeeze it:

    torch.index_select(x, dim=dim, index=torch.tensor([x.size(dim) - 1], device=x.device)).squeeze(dim=dim)
    

    Note: torch.select (e.g. x.select(dim, -1)) returns a view of the input tensor and drops the dimension, whereas index_select always returns a new tensor (a copy).

    Example:

    In [1]: dim = 1
    
    In [2]: x = torch.randn([3, 4, 5])
    
    In [3]: torch.index_select(x, dim=dim, index=torch.tensor([x.size(dim) - 1])).shape
    Out[3]: torch.Size([3, 1, 5])
    
    In [4]: torch.index_select(x, dim=dim, index=torch.tensor([x.size(dim) - 1])).squeeze(dim=dim).shape
    Out[4]: torch.Size([3, 5])
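If a copy is not needed, the slicing from the question generalizes to any number of dimensions by building the index tuple programmatically. The sketch below is one way to do it, assuming a helper named `last_along` (hypothetical, not a PyTorch function); it uses `Tensor.select` when the dimension should be dropped and also shows `torch.narrow` as another view-returning option:

```python
import torch


def last_along(x: torch.Tensor, dim: int, keepdim: bool = True) -> torch.Tensor:
    """Return the final entry of ``x`` along ``dim`` as a view (hypothetical helper).

    keepdim=True keeps ``dim`` with size 1, like ``x[:, -1:, :]`` for dim=1;
    keepdim=False removes it, like ``x[:, -1, :]``.
    """
    if keepdim:
        # Programmatic version of x[:, ..., -1:, ..., :]: a full slice on every
        # dimension except `dim`, which gets the slice -1: (negative dims work
        # because Python list indexing wraps them too).
        idx = [slice(None)] * x.dim()
        idx[dim] = slice(-1, None)
        return x[tuple(idx)]
    # Tensor.select takes a single (possibly negative) index and drops `dim`.
    return x.select(dim, -1)


x = torch.randn([3, 4, 5])
a = last_along(x, 1)                  # shape [3, 1, 5]
b = last_along(x, 1, keepdim=False)   # shape [3, 5]
# torch.narrow(dim, start, length) gives the same result as `a`, also as a view.
c = x.narrow(1, x.size(1) - 1, 1)

y = torch.randn([2, 3, 4, 5])
d = last_along(y, -1)                 # shape [2, 3, 4, 1]
```

Slicing and narrow keep the dimension, while select drops it. Because all three return views, writing into the result modifies x, which is not the case for the tensor returned by index_select.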