
Impact of the parameter dimension in gather function


I am trying to use the gather function in PyTorch but can't understand the role of the dim parameter.

Code:

import torch

t = torch.Tensor([[1,2],[3,4]])
print(torch.gather(t, 0, torch.LongTensor([[0,0],[1,0]])))

Output:

 1  2
 3  2
[torch.FloatTensor of size 2x2]

Dimension set to 1:

print(torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]])))

Output becomes:

 1  1
 4  3
[torch.FloatTensor of size 2x2]

How does the gather function actually work?


Solution

  • I realized how the gather function works.

    t = torch.Tensor([[1,2],[3,4]])
    index = torch.LongTensor([[0,0],[1,0]])
    torch.gather(t, 0, index)
    

    Since the dimension is zero, the index values select the row and each element keeps its own column, so the output will be:

    | t[index[0, 0], 0]   t[index[0, 1], 1] |
    | t[index[1, 0], 0]   t[index[1, 1], 1] |
    

    If the dimension is set to one, the output will become:

    | t[0, index[0, 0]]   t[0, index[0, 1]] |
    | t[1, index[1, 0]]   t[1, index[1, 1]] |
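
To convince myself of the two layouts above, here is a minimal pure-Python sketch of the same lookup on nested lists. `gather_2d` is a hypothetical helper I made up for illustration, not part of PyTorch; it only mimics what I understand `torch.gather` to do for the 2-D case.

```python
def gather_2d(t, dim, index):
    """Illustrative stand-in for torch.gather on 2-D nested lists (hypothetical helper)."""
    if dim not in (0, 1):
        raise ValueError("dim must be 0 or 1 for a 2-D input")
    rows, cols = len(index), len(index[0])
    out = [[None] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            if dim == 0:
                # index picks the row; the column j is kept as-is
                out[i][j] = t[index[i][j]][j]
            else:
                # index picks the column; the row i is kept as-is
                out[i][j] = t[i][index[i][j]]
    return out


t = [[1, 2], [3, 4]]
index = [[0, 0], [1, 0]]

along_dim0 = gather_2d(t, 0, index)
along_dim1 = gather_2d(t, 1, index)
print(along_dim0)  # [[1, 2], [3, 2]]
print(along_dim1)  # [[1, 1], [4, 3]]
```

Both results match the outputs printed in the question.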
    

    In general, for a 3-D tensor the output is specified by:
    
    out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
    out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
    out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
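
The three rules above differ only in which position gets replaced by the index value, so they can be sketched with a single loop. This is again a rough pure-Python illustration on nested lists; `gather_3d` is a hypothetical name and not a PyTorch function. The result takes the shape of `index`, which here is 1x2x2 while the input is 2x2x2.

```python
from itertools import product


def gather_3d(inp, dim, index):
    """Illustrative stand-in for torch.gather on 3-D nested lists (hypothetical helper)."""
    if dim not in (0, 1, 2):
        raise ValueError("dim must be 0, 1 or 2 for a 3-D input")
    d0, d1, d2 = len(index), len(index[0]), len(index[0][0])
    out = [[[None] * d2 for _ in range(d1)] for _ in range(d0)]
    for i, j, k in product(range(d0), range(d1), range(d2)):
        pos = [i, j, k]
        # Replace only the coordinate along `dim` with the index value
        pos[dim] = index[i][j][k]
        a, b, c = pos
        out[i][j][k] = inp[a][b][c]
    return out


inp = [[[1, 2], [3, 4]],
       [[5, 6], [7, 8]]]
index = [[[1, 0], [0, 1]]]

res_dim0 = gather_3d(inp, 0, index)
res_dim1 = gather_3d(inp, 1, index)
res_dim2 = gather_3d(inp, 2, index)
print(res_dim0)  # [[[5, 2], [3, 8]]]
print(res_dim1)  # [[[3, 2], [1, 4]]]
print(res_dim2)  # [[[2, 1], [3, 4]]]
```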
    

    Reference: http://pytorch.org/docs/master/torch.html?highlight=gather#torch.gather