Gather different pixels per image of an image stack with torch


I have a batch of images and a batch of (x, y) indices for each image. The indices differ from image to image, so I can't use simple indexing. What is the best or fastest way to get another batch containing the colors of the selected pixels for each image?

    import torch

    n_images = 4
    width = 100
    height = 100
    channels = 3
    n_samples = 30

    images = torch.rand((n_images, height, width, channels))
    # (x, y) pairs; scaling both by width is only valid because width == height here
    indices = (torch.rand((n_images, n_samples, 2)) * width).to(torch.int32)

    # preferred function
    # result = images[indices]
    # with result.shape = (n_images, n_samples, 3)

    # I just found this solution, but I would rather call a general torch function
    xs = indices.reshape((-1, 2))[:, 0]
    ys = indices.reshape((-1, 2))[:, 1]
    ix = torch.arange(n_images, dtype=torch.int32)
    ix = ix[..., None].expand((-1, n_samples)).flatten()
    
    result = images[ix, ys, xs].reshape((n_images, n_samples, 3))

Solution

  • You can use your indices tensor directly; you only need an additional tensor to index the batch dimension:

    n_images = 4
    width = 100
    height = 100
    channels = 3
    n_samples = 30
    
    images = torch.rand((n_images, height, width, channels))
    indices = (torch.rand((n_images, n_samples, 2)) * width).to(torch.int32)
    
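    # Row i holds the image index i repeated n_samples times: shape (n_images, n_samples)
    batch_indices = torch.arange(n_images).view(-1, 1).expand(-1, n_samples)
    # indices[..., 1] are the ys (height axis), indices[..., 0] are the xs (width axis)
    result = images[batch_indices, indices[..., 1], indices[..., 0]]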
    

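    This follows your convention of images[ix, ys, xs], where the ys index the height dimension of the tensor and the xs index the width dimension. All three index tensors have shape (n_images, n_samples), so the result has shape (n_images, n_samples, channels).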
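
As a rough check of the indexing above, here is a minimal sketch that wraps it in a helper and compares it with an explicit Python loop. The name `gather_pixels` is hypothetical and not part of torch. The example deliberately uses `width != height` (an assumption that differs from the question's setup) so that swapping x and y would be caught, and it casts the indices to `long` because some PyTorch versions only accept int64, bool or uint8 index tensors.

```python
import torch


def gather_pixels(images: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Pick one pixel per (image, sample) pair (hypothetical helper).

    images:  (n_images, height, width, channels)
    indices: (n_images, n_samples, 2), last axis ordered as (x, y)
    returns: (n_images, n_samples, channels)
    """
    n_images = images.shape[0]
    # Shape (n_images, 1) broadcasts against the (n_images, n_samples) xs and ys,
    # so no explicit expand() is needed.
    batch_indices = torch.arange(n_images, device=images.device)[:, None]
    xs = indices[..., 0].long()  # width axis
    ys = indices[..., 1].long()  # height axis
    return images[batch_indices, ys, xs]


torch.manual_seed(0)
n_images, height, width, channels, n_samples = 4, 100, 80, 3, 30

images = torch.rand((n_images, height, width, channels))
# Sample x and y with their own upper bounds so every index is in range.
xs = torch.randint(0, width, (n_images, n_samples))
ys = torch.randint(0, height, (n_images, n_samples))
indices = torch.stack((xs, ys), dim=-1)  # (n_images, n_samples, 2), ordered (x, y)

result = gather_pixels(images, indices)

# Reference result: look up every pixel one at a time.
expected = torch.stack([
    torch.stack([images[i, y, x] for x, y in indices[i].tolist()])
    for i in range(n_images)
])

assert result.shape == (n_images, n_samples, channels)
assert torch.equal(result, expected)
```

The loop is only there for verification; the single advanced-indexing call does the same lookups in one operation.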