I want to slice a 3D tensor in PyTorch. The 3D tensor src_tensor has the shape (batch, max_len, hidden_dim), and I have a 1D index vector indices with the shape (batch,). I want to slice along the second dimension of src_tensor. I can achieve this with the following code:
import torch
nums = 30
l = [i for i in range(nums)]
src_tensor = torch.Tensor(l).reshape((3,5,2))
indices = [1,2,3]
slice_tensor = torch.zeros((3,2,2))
for i in range(3):
    p1, p2 = indices[i], indices[i] + 1
    slice_tensor[i, :, :] = src_tensor[i, [p1, p2], :]
print(src_tensor)
print(indices)
print(slice_tensor)
"""
tensor([[[ 0., 1.],
[ 2., 3.],
[ 4., 5.],
[ 6., 7.],
[ 8., 9.]],
[[10., 11.],
[12., 13.],
[14., 15.],
[16., 17.],
[18., 19.]],
[[20., 21.],
[22., 23.],
[24., 25.],
[26., 27.],
[28., 29.]]])
[1, 2, 3]
tensor([[[ 2., 3.],
[ 4., 5.]],
[[14., 15.],
[16., 17.]],
[[26., 27.],
[28., 29.]]])
"""
My question is whether the above code can be simplified, for example, by eliminating the for loop.
Since your indexing along the second dimension depends on the batch index, I think you need to leverage some more complex indexing schemes:
import torch
# arbitrary source tensor
src_tensor = torch.Tensor(range(30)).view((3,5,2))
B, N, D = src_tensor.shape
# arbitrary, non-consecutive start indices (one per batch element)
ids = [0, 1, 3]
# slicing
sliced = torch.empty(B, 2, D)  # empty instead of zeros, since every entry is overwritten below
sliced[:, 0, :] = src_tensor[range(B), ids]
sliced[:, 1, :] = src_tensor[:,1:,:][range(B), ids]
The slicing on the last line (src_tensor[:, 1:, :]) offsets the tensor by 1 in the second dimension, which makes it possible to access the index+1 values without doing any addition or copying.
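As a quick sanity check (reusing B, D, src_tensor and ids from above), the result can be compared against the loop-based approach from the question:

expected = torch.empty(B, 2, D)
for i in range(B):
    expected[i] = src_tensor[i, ids[i]:ids[i] + 2, :]  # loop-based reference, as in the question
assert torch.equal(sliced, expected)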
Note that if you know for sure that your indices are all consecutive (like in your example), then using narrow
instead of slicing will be even better. Alternatively, for some even more complex indexing schemes, you can also try torch.gather.
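For reference, here is a rough gather-based sketch of the same batched slice (the names rows and sliced_gather are mine, not from the original code); torch.gather needs an index tensor with the same number of dimensions as its input, so the per-batch start indices are expanded to shape (B, 2, D):

idx = torch.tensor(ids)                      # (B,) start index per batch element
rows = torch.stack([idx, idx + 1], dim=1)    # (B, 2): index and index + 1
rows = rows.unsqueeze(-1).expand(B, 2, D)    # (B, 2, D), matching src_tensor's last dim
sliced_gather = src_tensor.gather(1, rows)   # same result as the slicing approach above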
Note: as far as I'm aware, this indexing scheme is consistent with numpy's, so most of the documentation you can find for numpy slicing also applies to torch.
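For instance, the same advanced-indexing pattern works essentially unchanged in numpy (a small illustration, not from the original post):

import numpy as np

src_np = np.arange(30).reshape(3, 5, 2)
print(src_np[range(3), [0, 1, 3]])   # picks row 0, 1 and 3 from the three batch elements respectively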