Tags: python, deep-learning, pytorch, torch, tensor

How to dynamically index a tensor in PyTorch?


For example, I have a tensor:

tensor = torch.rand(12, 512, 768)

And I have a list of indices, say:

[0,2,3,400,5,32,7,8,321,107,100,511]

For each of the 12 entries along dimension 0, I want to select one of the 512 elements along dimension 1 (the second dimension), using the corresponding index from the list. The resulting tensor should then have shape (12, 1, 768).

Is there a way to do it?


Solution

  • You can also do this purely in PyTorch, without a Python loop, by combining indexing with torch.split:

    import torch

    tensor = torch.rand(12, 512, 768)

    # list of indices into dimension 1 (size 512)
    idx_list = [0, 2, 3, 400, 5, 32, 7, 8, 321, 107, 100, 511]
    # convert the list to an int64 tensor for indexing
    idx_tensor = torch.tensor(idx_list)

    # select those positions along dim 1 -> (12, 12, 768), then split along dim 1
    list_of_tensors = tensor[:, idx_tensor, :].split(1, dim=1)
    

    Indexing with tensor[:, idx_tensor, :] returns a tensor of shape
    (12, len(idx_list), 768): the second dimension equals the number of indices (12 here).

    torch.split with a split size of 1 along dim=1 then cuts this tensor into a tuple of tensors, each of shape (12, 1, 768).

    So list_of_tensors ends up holding one tensor per index, with these shapes:

    [torch.Size([12, 1, 768]),
     torch.Size([12, 1, 768]),
     torch.Size([12, 1, 768]),
     torch.Size([12, 1, 768]),
     torch.Size([12, 1, 768]),
     torch.Size([12, 1, 768]),
     torch.Size([12, 1, 768]),
     torch.Size([12, 1, 768]),
     torch.Size([12, 1, 768]),
     torch.Size([12, 1, 768]),
     torch.Size([12, 1, 768]),
     torch.Size([12, 1, 768])]
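The shapes above can be checked with a short script. It also makes one point explicit: each tensor in list_of_tensors applies a single index to all 12 rows. If the goal is the one stated in the question, a single (12, 1, 768) tensor where row i takes index idx_list[i], advanced indexing with a row index from torch.arange does that directly, and torch.gather is an equivalent alternative. This is a minimal sketch assuming the shapes from the question; the names selected, rows, per_row, gather_idx and via_gather are illustrative.

```python
import torch

tensor = torch.rand(12, 512, 768)
idx_list = [0, 2, 3, 400, 5, 32, 7, 8, 321, 107, 100, 511]
idx_tensor = torch.tensor(idx_list)

# Approach from the answer: every index is applied to every row, then split.
selected = tensor[:, idx_tensor, :]           # shape (12, 12, 768)
list_of_tensors = selected.split(1, dim=1)    # tuple of 12 tensors, each (12, 1, 768)

# Piece k holds position idx_list[k] for all 12 rows (k=3 -> position 400).
assert torch.equal(list_of_tensors[3], tensor[:, 400:401, :])

# One index per row: row i takes tensor[i, idx_list[i], :].
rows = torch.arange(tensor.size(0))              # tensor([0, 1, ..., 11])
per_row = tensor[rows, idx_tensor].unsqueeze(1)  # (12, 768) -> (12, 1, 768)

# Same selection with torch.gather: the index must be expanded to (12, 1, 768).
gather_idx = idx_tensor.view(-1, 1, 1).expand(-1, 1, tensor.size(2))
via_gather = torch.gather(tensor, 1, gather_idx)

print(per_row.shape)  # torch.Size([12, 1, 768])
```

Advanced indexing is usually the more readable option; torch.gather is handy when the index tensor already has the same number of dimensions as the input.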