I have 4 PyTorch tensors:
data of shape (l, m, n)
a of shape (k,) and datatype long
b of shape (k,) and datatype long
c of shape (k,) and datatype long
I want to slice the tensor data such that it picks the element addressed by a in the 0th dimension. In the 1st and 2nd dimensions, I want to pick a patch of values based on the element addressed by b and c. Specifically, I want to pick 9 values: a 3x3 patch around the value addressed by (b, c). Thus my sliced tensor should have shape (k, 3, 3).
MWE:
data = torch.arange(200).reshape((2, 10, 10))
a = torch.Tensor([1, 0, 1, 1, 0]).long()
b = torch.Tensor([5, 6, 3, 4, 7]).long()
c = torch.Tensor([4, 3, 7, 6, 5]).long()
data1 = data[a, b-1:b+2, c-1:c+2] # gives error
>>> TypeError: only integer tensors of a single element can be converted to an index
Expected output
data1[0] = [[143,144,145],[153,154,155],[163,164,165]]
data1[1] = [[52,53,54],[62,63,64],[72,73,74]]
data1[2] = [[126,127,128],[136,137,138],[146,147,148]]
and so on
How can I do this without using a for loop?
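PS: data is constructed so that the locations addressed by a, b, c are within bounds.

I would first expand the indices and then add shifts to the repeated indices. Note that the row and column shifts follow different patterns: the row shift repeats each offset three times ([-1, -1, -1, 0, 0, 0, 1, 1, 1]), while the column shift cycles through the offsets ([-1, 0, 1, -1, 0, 1, -1, 0, 1]). Together they enumerate the 3x3 patch in row-major order. For example,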
import torch
data = torch.arange(200).reshape((2, 10, 10))
a = torch.Tensor([1, 0, 1, 1, 0]).long()
b = torch.Tensor([5, 6, 3, 4, 7]).long()
c = torch.Tensor([4, 3, 7, 6, 5]).long()
# Repeat each index 9 times (kernel_size^2), one copy per element of the 3x3 patch
index1 = a.repeat_interleave(9)
index2 = b.repeat_interleave(9)
# Row offsets per patch: [-1, -1, -1, 0, 0, 0, 1, 1, 1], tiled for the 5 patches -> shape (45,)
shift = torch.arange(-1, 2).repeat_interleave(3).repeat(5)
shifted_index2 = index2 + shift
index3 = c.repeat_interleave(9)
# Column offsets per patch: [-1, 0, 1, -1, 0, 1, -1, 0, 1], tiled for the 5 patches -> shape (45,)
shift = torch.arange(-1, 2).repeat(3).repeat(5)
shifted_index3 = index3 + shift
# Use the indexing arrays to select the patches
data1 = data[index1, shifted_index2, shifted_index3].view(5, 3, 3)
print(data1[0])
print(data1[1])
print(data1[2])
The output:
tensor([[143, 144, 145],
        [153, 154, 155],
        [163, 164, 165]])
tensor([[52, 53, 54],
        [62, 63, 64],
        [72, 73, 74]])
tensor([[126, 127, 128],
        [136, 137, 138],
        [146, 147, 148]])
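
As a possible alternative to the repeat_interleave approach above, the same patches can be gathered by letting PyTorch broadcast the index tensors, which avoids hard-coding k = 5 and the patch size. This is only a sketch: the helper name extract_patches and its kernel_size parameter are made up for illustration, and it assumes an odd kernel_size and that every patch lies fully inside data (no bounds checking is done).

```python
import torch

def extract_patches(data, a, b, c, kernel_size=3):
    """Gather a kernel_size x kernel_size patch centred at (b[i], c[i]) from data[a[i]].

    Hypothetical helper: assumes kernel_size is odd and all patches are in bounds.
    """
    half = kernel_size // 2
    offsets = torch.arange(-half, half + 1, device=data.device)  # e.g. [-1, 0, 1]
    rows = b[:, None, None] + offsets[None, :, None]  # shape (k, kernel_size, 1)
    cols = c[:, None, None] + offsets[None, None, :]  # shape (k, 1, kernel_size)
    # The three index tensors broadcast to (k, kernel_size, kernel_size)
    return data[a[:, None, None], rows, cols]

data = torch.arange(200).reshape(2, 10, 10)
a = torch.tensor([1, 0, 1, 1, 0])
b = torch.tensor([5, 6, 3, 4, 7])
c = torch.tensor([4, 3, 7, 6, 5])

data1 = extract_patches(data, a, b, c)
print(data1.shape)  # torch.Size([5, 3, 3])
print(data1[0])
```

Because the offsets are added through broadcasting, no flattened index of length k * 9 is built and no final view is needed; the result already has shape (k, 3, 3).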