I am trying to use the gather function in PyTorch, but I can't understand the role of the dim parameter.
Code:
import torch
t = torch.Tensor([[1,2],[3,4]])
print(torch.gather(t, 0, torch.LongTensor([[0,0],[1,0]])))
Output:
1 2
3 2
[torch.FloatTensor of size 2x2]
With dim set to 1:
print(torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]])))
Output becomes:
1 1
4 3
[torch.FloatTensor of size 2x2]
How does the gather function actually work?
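As a quick check, here is a sketch of both calls in current PyTorch. It assumes torch.tensor (which infers an int64 dtype from integer literals) in place of the legacy torch.Tensor/torch.LongTensor constructors, so the values come out as integers rather than floats.

```python
import torch

# Same data as above, built with the modern constructor (dtype int64)
t = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]])

out0 = torch.gather(t, 0, index)  # index chooses the row; the column is fixed
out1 = torch.gather(t, 1, index)  # index chooses the column; the row is fixed

print(out0.tolist())  # [[1, 2], [3, 2]]
print(out1.tolist())  # [[1, 1], [4, 3]]
```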
I figured out how the gather function works.
import torch
t = torch.Tensor([[1,2],[3,4]])
index = torch.LongTensor([[0,0],[1,0]])
torch.gather(t, 0, index)
Since dim is 0, the output is:
| t[index[0, 0], 0] t[index[0, 1], 1] |
| t[index[1, 0], 0] t[index[1, 1], 1] |
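The dim=0 case can also be written with plain advanced indexing: index supplies the row of every output element, and the column is simply that element's own column position. The helper tensor cols below is only an illustration of that idea, not part of the gather API.

```python
import torch

t = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]])

# Column position of every entry of index: [[0, 1], [0, 1]]
cols = torch.arange(index.size(1)).expand_as(index)

# out[i][j] = t[index[i][j]][j]
manual = t[index, cols]

assert torch.equal(manual, torch.gather(t, 0, index))
print(manual.tolist())  # [[1, 2], [3, 2]]
```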
If dim is set to 1, the output becomes:
| t[0, index[0, 0]] t[0, index[0, 1]] |
| t[1, index[1, 0]] t[1, index[1, 1]] |
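Likewise for dim=1, the row comes from each element's own position and index supplies the column. Again, rows is just an illustrative helper tensor, assumed here for the sake of the comparison.

```python
import torch

t = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]])

# Row position of every entry of index: [[0, 0], [1, 1]]
rows = torch.arange(index.size(0)).unsqueeze(1).expand_as(index)

# out[i][j] = t[i][index[i][j]]
manual = t[rows, index]

assert torch.equal(manual, torch.gather(t, 1, index))
print(manual.tolist())  # [[1, 1], [4, 3]]
```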
In general, the output has the same shape as index, and for a 3-D tensor each element is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
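To check these formulas, here is a loop-based sketch. gather_reference is a hypothetical helper written for illustration, not a PyTorch function. It builds an output shaped like index and, at every position, replaces the coordinate along dim with the value stored in index, which is exactly what the three lines above say; the same loop also works for tensors of any rank. It is far slower than torch.gather and is meant only for understanding.

```python
import itertools
import torch

def gather_reference(inp, dim, index):
    """Hypothetical loop-based stand-in for torch.gather (illustration only)."""
    out = torch.empty(index.shape, dtype=inp.dtype)
    # Visit every position of the output (same shape as index)
    for pos in itertools.product(*(range(s) for s in index.shape)):
        src = list(pos)
        # Swap the coordinate along `dim` for the value stored in index
        src[dim] = index[pos].item()
        out[pos] = inp[tuple(src)]
    return out

torch.manual_seed(0)
inp = torch.arange(24).reshape(2, 3, 4)
for dim in range(3):
    # Random valid indices along `dim`, same shape as inp
    index = torch.randint(0, inp.size(dim), inp.shape)
    assert torch.equal(gather_reference(inp, dim, index),
                       torch.gather(inp, dim, index))
print("all three dims match")
```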
Reference: http://pytorch.org/docs/master/torch.html?highlight=gather#torch.gather