
Slice a multidimensional pytorch tensor based on values in other tensors


I have 4 PyTorch tensors:

  • data of shape (l, m, n)
  • a of shape (k,) and datatype long
  • b of shape (k,) and datatype long
  • c of shape (k,) and datatype long

I want to slice the tensor data so that, for each of the k entries, it picks the index given by a in the 0th dimension. In the 1st and 2nd dimensions, I want to pick a patch of values around the location given by b and c. Specifically, I want 9 values - a 3x3 patch centered at (b, c). Thus my sliced tensor should have shape (k, 3, 3).

MWE:

data = torch.arange(200).reshape((2, 10, 10))
a = torch.tensor([1, 0, 1, 1, 0])
b = torch.tensor([5, 6, 3, 4, 7])
c = torch.tensor([4, 3, 7, 6, 5])

data1 = data[a, b-1:b+2, c-1:c+2]  # gives error

>>> TypeError: only integer tensors of a single element can be converted to an index

Expected output

data1[0] = [[143,144,145],[153,154,155],[163,164,165]]
data1[1] = [[52,53,54],[62,63,64],[72,73,74]]
data1[2] = [[126,127,128],[136,137,138],[146,147,148]]
and so on

How can I do this without using a for loop?

PS:

  • I've padded data to make sure that the locations addressed by a,b,c are within the limit.
  • I don't need gradients to flow through this operation. So, I can convert these to NumPy and slice if that is faster. But I would prefer a solution in PyTorch.
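For reference, the straightforward loop version being avoided would look something like this (a sketch using the MWE tensors above; `torch.stack` collects the per-row slices):

```python
import torch

data = torch.arange(200).reshape((2, 10, 10))
a = torch.tensor([1, 0, 1, 1, 0])
b = torch.tensor([5, 6, 3, 4, 7])
c = torch.tensor([4, 3, 7, 6, 5])

# Loop baseline: slice one 3x3 patch per (a, b, c) triple, then stack.
data1 = torch.stack([data[ai, bi - 1:bi + 2, ci - 1:ci + 2]
                     for ai, bi, ci in zip(a, b, c)])
# data1.shape == (5, 3, 3)
```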

Solution

  • I would first repeat each index 9 times (one copy per patch element) and then add the appropriate offsets to the repeated indices. Note that the row and column offsets use opposite repeat patterns (repeat_interleave vs. repeat). For example,

    import torch
    
    data = torch.arange(200).reshape((2, 10, 10))
    a = torch.tensor([1, 0, 1, 1, 0])
    b = torch.tensor([5, 6, 3, 4, 7])
    c = torch.tensor([4, 3, 7, 6, 5])
    
    index1 = a.repeat_interleave(9)  # one copy per patch element; 9 = kernel_size^2
    
    index2 = b.repeat_interleave(9)
    # Row offsets, shape (5*9,); pattern per patch: [-1, -1, -1, 0, 0, 0, 1, 1, 1]
    shift = torch.arange(-1, 2).repeat_interleave(3).repeat(5)
    shifted_index2 = index2 + shift
    
    index3 = c.repeat_interleave(9)
    # Column offsets, shape (5*9,); pattern per patch: [-1, 0, 1, -1, 0, 1, -1, 0, 1]
    shift = torch.arange(-1, 2).repeat(3).repeat(5)
    shifted_index3 = index3 + shift
    
    # Use the indexing arrays to select the patches
    data1 = data[index1, shifted_index2, shifted_index3].view(5, 3, 3)
    
    print(data1[0])
    print(data1[1])
    print(data1[2])
    

    The output:

    tensor([[143, 144, 145],
            [153, 154, 155],
            [163, 164, 165]])
    tensor([[52, 53, 54],
            [62, 63, 64],
            [72, 73, 74]])
    tensor([[126, 127, 128],
            [136, 137, 138],
            [146, 147, 148]])
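
    As a side note, the same result can be obtained without flattening the indices at all, by letting PyTorch's advanced indexing broadcast the index tensors against each other. This is a sketch under the question's assumption that every patch lies within bounds:

    ```python
    import torch

    data = torch.arange(200).reshape((2, 10, 10))
    a = torch.tensor([1, 0, 1, 1, 0])
    b = torch.tensor([5, 6, 3, 4, 7])
    c = torch.tensor([4, 3, 7, 6, 5])

    off = torch.arange(-1, 2)                      # [-1, 0, 1]
    rows = b[:, None, None] + off[None, :, None]   # (k, 3, 1)
    cols = c[:, None, None] + off[None, None, :]   # (k, 1, 3)

    # The three index tensors broadcast to (k, 3, 3).
    patches = data[a[:, None, None], rows, cols]
    ```

    This avoids the explicit repeat_interleave/repeat bookkeeping and generalizes to other patch sizes by changing `off`.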