
Use torch.gather to select images from tensor


I have a tensor of images of size (3600, 32, 32, 3) and a multi-hot tensor [0, 1, 1, 0, ...] of size (3600, 1). I want to select the images that correspond to a 1 in the multi-hot tensor. I am trying to understand how to use torch.gather:

tensorA.gather(0, tensorB)

This gives me dimension errors, and I don't understand how to reshape the tensors to make it work.
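A minimal reproduction of the failing call, with random data standing in for the real images (an assumption; only the shapes come from the question). The index has 2 dimensions while the input has 4, which is what gather rejects:

```python
import torch

# Random stand-ins for the real data; only the shapes are taken from the question.
images = torch.randn(3600, 32, 32, 3)
mask = torch.randint(0, 2, (3600, 1))  # multi-hot: one 0/1 flag per image (int64)

try:
    images.gather(0, mask)
    raised = False
except RuntimeError as err:
    # gather requires index to have the same number of dimensions as input
    # (here 2 vs. 4), so the call fails before any values are read.
    raised = True
    print(err)

print(raised)
```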


Solution

  • With torch.gather, the input and index tensors must have the same number of dimensions, and the output has the same shape as index. Also, index is not a multi-hot mask: it holds the positions of the values to pick along the given dimension.

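    Instead of gather, you can index the tensor with the positions of the 1s in the multi-hot tensor. The torch.where call finds the indices where the multi-hot tensor equals 1, and the last line selects the images at those indices. With a (3600, 1) mask as in the question, torch.where(tensorB == 1)[0] still returns the row indices, so the same two lines work.

    Code:

    import torch

    tensorA = torch.randn(4, 32, 32, 3)  # 4 images of shape (32, 32, 3)
    tensorB = torch.tensor([0, 1, 1, 0])  # multi-hot selection flags

    # Indices of the entries equal to 1 -> tensor([1, 2])
    tensorB_where = torch.where(tensorB == 1)[0]
    # Select those images -> shape (2, 32, 32, 3)
    result = tensorA[tensorB_where]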
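For comparison, here is a sketch of two other ways to get the same selection, using the answer's toy sizes (4 images rather than 3600). The first indexes directly with a boolean mask. The second does use torch.gather: it turns the mask into positions and expands that index to the input's full shape, which satisfies gather's same-number-of-dimensions rule. The variable names are illustrative only.

```python
import torch

# Toy data mirroring the answer's example (4 images instead of 3600).
images = torch.randn(4, 32, 32, 3)
multi_hot = torch.tensor([0, 1, 1, 0])

# 1) Boolean-mask indexing: keep the images whose flag is 1.
by_mask = images[multi_hot.bool()]

# 2) torch.gather: convert the mask to positions, then give the index the
#    same number of dimensions (and trailing sizes) as the input.
positions = torch.nonzero(multi_hot, as_tuple=True)[0]  # tensor([1, 2])
index = positions.reshape(-1, 1, 1, 1).expand(-1, *images.shape[1:])  # (2, 32, 32, 3)
by_gather = images.gather(0, index)  # out[i, j, k, l] = images[positions[i], j, k, l]

print(by_mask.shape)                     # torch.Size([2, 32, 32, 3])
print(torch.equal(by_mask, by_gather))   # True
```

For a mask of shape (3600, 1) like the one in the question, call `.squeeze(1)` on it first so it is 1-D. The boolean-mask form is the more direct choice; the gather form mainly shows what shape of index gather expects.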