
PyTorch batchwise indexing


I am looking for a way to do batchwise indexing on tensors.

If I have a variable Q of size 1000, I can get the elements I want with Q[index], where index is a vector containing the indices of the wanted elements.

Now I would like to do the same for higher-dimensional tensors. Suppose Q has shape n x m and I have an index matrix of shape n x p.

My goal is to pick, for each of the n rows, the p specified elements out of that row's m elements.

But Q[index] does not work in this situation.
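A quick check shows what Q[index] actually does. This is a sketch with a random Q of shape 5 x 4 and a random index of shape 5 x 3 standing in for the real data. A single 2D integer tensor used as the subscript indexes only the first dimension, so each entry selects a whole row of Q. The result therefore has shape n x p x m rather than n x p, and the call fails with an index error whenever an index value is n or larger.

```python
import torch

n, m, p = 5, 4, 3
Q = torch.rand(n, m)
# Random column indices in [0, m); because m <= n here they are also valid row indices
index = torch.randint(0, m, (n, p))

# Advanced indexing with one 2D tensor selects along dim 0 (whole rows),
# not "column index[i, j] of row i"
rows = Q[index]
print(rows.shape)  # torch.Size([5, 3, 4])
```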

Do you have any thoughts on how to handle this?


Solution

  • This seems to be a simple application of torch.gather, which doesn't require any additional reshaping of the data or index tensor:

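    >>> import torch
    >>> Q = torch.rand(5, 4)
    >>> Q  # random values; yours will differ
    tensor([[0.8462, 0.3064, 0.2549, 0.2149],
            [0.6801, 0.5483, 0.5522, 0.6852],
            [0.1587, 0.4144, 0.8843, 0.6108],
            [0.5265, 0.8269, 0.8417, 0.6623],
            [0.8549, 0.6437, 0.4282, 0.2792]])
    
    >>> index = torch.tensor([[0, 1, 2], [2, 3, 1], [0, 1, 2], [2, 2, 2], [1, 1, 2]])
    >>> index
    tensor([[0, 1, 2],
            [2, 3, 1],
            [0, 1, 2],
            [2, 2, 2],
            [1, 1, 2]])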
    

    Applying gather along dim=1 returns a tensor out with the same shape as index (n x p), such that:

    out[i, j] = Q[i, index[i, j]]
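Written out as an explicit Python loop, the formula reads as follows. This is a reference sketch only: `gather_loop` is a hypothetical helper name, and the loop is far slower than a single vectorized call. It reuses the Q and index values from the example above.

```python
import torch

def gather_loop(Q: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference helper: out[i, j] = Q[i, index[i, j]]."""
    n, p = index.shape
    out = Q.new_empty((n, p))  # same dtype/device as Q, same shape as index
    for i in range(n):         # for every row of Q ...
        for j in range(p):     # ... pick p entries by their column index
            out[i, j] = Q[i, index[i, j]]
    return out

Q = torch.tensor([[0.8462, 0.3064, 0.2549, 0.2149],
                  [0.6801, 0.5483, 0.5522, 0.6852],
                  [0.1587, 0.4144, 0.8843, 0.6108],
                  [0.5265, 0.8269, 0.8417, 0.6623],
                  [0.8549, 0.6437, 0.4282, 0.2792]])
index = torch.tensor([[0, 1, 2],
                      [2, 3, 1],
                      [0, 1, 2],
                      [2, 2, 2],
                      [1, 1, 2]])

out = gather_loop(Q, index)
print(out[1])  # tensor([0.5522, 0.6852, 0.5483])
```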
    

    This is done with the following call to torch.Tensor.gather on Q:

    >>> Q.gather(dim=1, index=index)
    tensor([[0.8462, 0.3064, 0.2549],
            [0.5522, 0.6852, 0.5483],
            [0.1587, 0.4144, 0.8843],
            [0.8417, 0.8417, 0.8417],
            [0.6437, 0.6437, 0.4282]])
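For comparison, the same selection can be written with advanced indexing by pairing every column index with its row number, or, in PyTorch 1.9 and later, with torch.take_along_dim, the counterpart of NumPy's take_along_axis. The sketch below reuses the example values and checks that all three agree. The PyTorch documentation describes the index argument of gather and take_along_dim as a LongTensor, and torch.tensor infers int64 for Python integers.

```python
import torch

Q = torch.tensor([[0.8462, 0.3064, 0.2549, 0.2149],
                  [0.6801, 0.5483, 0.5522, 0.6852],
                  [0.1587, 0.4144, 0.8843, 0.6108],
                  [0.5265, 0.8269, 0.8417, 0.6623],
                  [0.8549, 0.6437, 0.4282, 0.2792]])
index = torch.tensor([[0, 1, 2], [2, 3, 1], [0, 1, 2], [2, 2, 2], [1, 1, 2]])

# 1) gather along dim=1, as in the answer above
out_gather = Q.gather(dim=1, index=index)

# 2) Advanced indexing: a column of row numbers (n x 1) broadcasts against index (n x p),
#    so element (i, j) reads Q[i, index[i, j]]
rows = torch.arange(Q.size(0)).unsqueeze(1)
out_adv = Q[rows, index]

# 3) take_along_dim (PyTorch >= 1.9)
out_take = torch.take_along_dim(Q, index, dim=1)

assert torch.equal(out_gather, out_adv)
assert torch.equal(out_gather, out_take)
print(out_gather.shape)  # torch.Size([5, 3])
```

The gather form extends directly to an extra batch dimension: for Q of shape b x n x m and index of shape b x n x p, Q.gather(dim=2, index=index) picks the p entries from the last dimension for every (batch, row) pair.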