Assume I have the following tensor:
import torch
arr = torch.randint(0, 9, (100, 50, 3))
I want to collect, for example, two elements of that tensor. Let's start with the 6th and 56th ones:
indices = torch.tensor([5, 55])
partial_arr = arr[indices]
This gives me a tensor of shape
torch.Size([2, 50, 3])
Now, let's assume that from the first element I want to collect the elements at indices 5 through 9 (the slice 5:10):
first_result = partial_arr[0, 5:10]
and from the second element, the elements at indices 10 through 14 (the slice 10:15):
second_result = partial_arr[1, 10:15]
Since I want everything in one tensor, I can do:
final_result = torch.cat([first_result, second_result])
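For reference, the same multi-step result can be written for any number of rows with a list comprehension. This is a minimal sketch; rows, starts and n are hypothetical names introduced here for the selected rows, the per-row start offsets, and the (constant) slice length.

```python
import torch

arr = torch.randint(0, 9, (100, 50, 3))

rows = torch.tensor([5, 55])    # hypothetical: rows of arr to pick (the 6th and 56th)
starts = torch.tensor([5, 10])  # hypothetical: start of each slice along dim 1
n = 5                           # hypothetical: slice length, same for every row

# One slice per selected row, then concatenate along dim 0 (same as the torch.cat above)
final_result = torch.cat(
    [arr[r, s:s + n] for r, s in zip(rows.tolist(), starts.tolist())]
)
print(final_result.shape)  # torch.Size([10, 3])
```

This still performs one slicing operation per row in a Python loop, which is exactly what the question is trying to avoid.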
How can I get final_result with a single operation on the original tensor arr = torch.randint(0, 9, (100, 50, 3))?
Assuming the number of sliced elements is the same for every row, you can build a range tensor with torch.arange and shift each row by its starting index:
>>> idx = torch.tensor([5, 10])
>>> idx_ = torch.arange(5)[None] + idx[:, None]
>>> idx_
tensor([[ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14]])
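The index matrix relies on broadcasting: torch.arange(n)[None] has shape (1, n) and idx[:, None] has shape (k, 1), so their sum has shape (k, n), with row i equal to idx[i], idx[i] + 1, ..., idx[i] + n - 1. A small check of those shapes, with n as a hypothetical name for the slice length:

```python
import torch

idx = torch.tensor([5, 10])  # per-row start offsets
n = 5                        # slice length (assumed constant across rows)

offsets = torch.arange(n)[None]  # shape (1, n): [[0, 1, 2, 3, 4]]
starts = idx[:, None]            # shape (k, 1): [[5], [10]]
idx_ = offsets + starts          # broadcasts to shape (k, n)

print(offsets.shape, starts.shape, idx_.shape)
print(idx_.tolist())  # [[5, 6, 7, 8, 9], [10, 11, 12, 13, 14]]
```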
Then expand idx_ so that it has the same last-dimension size as partial_arr:
>>> idx_ = idx_[...,None].expand(-1,-1,partial_arr.size(-1))
# shaped torch.Size([2, 5, 3])
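expand does not copy data: it returns a view in which the new last dimension has stride 0, so all 3 channels read the same index. A quick check; the repeat call is only there to show that the values match a materialized copy, and C is a hypothetical name for the last-dimension size of partial_arr.

```python
import torch

idx_ = torch.arange(5)[None] + torch.tensor([5, 10])[:, None]  # shape (2, 5)
C = 3  # size of the last dimension of partial_arr

expanded = idx_[..., None].expand(-1, -1, C)  # view, shape (2, 5, 3)
copied = idx_[..., None].repeat(1, 1, C)      # materialized copy, same values

print(expanded.shape)                         # torch.Size([2, 5, 3])
print(expanded.stride())                      # (5, 1, 0): stride 0 on the expanded dim
print(expanded.data_ptr() == idx_.data_ptr())  # True: shares idx_'s storage
print(torch.equal(expanded, copied))          # True
```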
Finally, gather the values using torch.gather:
>>> partial_arr.gather(1, idx_)
tensor([[[8, 3, 1],
         [2, 4, 6],
         [4, 4, 5],
         [2, 8, 6],
         [3, 7, 0]],

        [[3, 6, 7],
         [5, 7, 4],
         [1, 5, 4],
         [4, 5, 3],
         [7, 1, 2]]])
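The question asks for a result computed directly from arr. Below is a minimal sketch that combines the steps above into one function, assuming a constant slice length; gather_slices, rows, starts and n are hypothetical names. Note that gather returns shape (k, n, C), whereas torch.cat in the question returns (k * n, C), so a reshape is needed to compare them. The sketch also shows plain advanced (integer-array) indexing, arr[rows[:, None], cols], as a swapped-in alternative that reaches the same values in one indexing operation on arr without the expand step.

```python
import torch

def gather_slices(arr, rows, starts, n):
    """Hypothetical helper: for each i, take arr[rows[i], starts[i]:starts[i] + n]."""
    cols = torch.arange(n, device=arr.device)[None] + starts[:, None]  # (k, n)
    index = cols[..., None].expand(-1, -1, arr.size(-1))              # (k, n, C)
    return arr[rows].gather(1, index)                                 # (k, n, C)

arr = torch.randint(0, 9, (100, 50, 3))
rows = torch.tensor([5, 55])
starts = torch.tensor([5, 10])

out = gather_slices(arr, rows, starts, 5)              # shape (2, 5, 3)
reference = torch.cat([arr[5, 5:10], arr[55, 10:15]])  # shape (10, 3), as in the question
print(torch.equal(out.reshape(-1, arr.size(-1)), reference))  # True

# Alternative: advanced indexing broadcasts rows[:, None] (k, 1) against cols (k, n),
# and the untouched last dimension is appended, giving (k, n, C)
cols = torch.arange(5)[None] + starts[:, None]
alt = arr[rows[:, None], cols]
print(torch.equal(alt, out))  # True
```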