python pytorch tensor argmax

Convert a one-hot encoded dimension into the index of the position of the 1


I have a three-dimensional tensor of shape [batch_size, sequence_length, number_of_tokens]. The last dimension is one-hot encoded. I want to get a two-dimensional tensor of shape [batch_size, sequence_length] in which each entry is the index of the '1' in the number_of_tokens dimension.

For example, to turn a tensor of shape (2, 3, 4):

[[[0, 1, 0, 0],
  [1, 0, 0, 0],
  [0, 0, 0, 1]],
 [[1, 0, 0, 0],
  [1, 0, 0, 0],
  [0, 0, 1, 0]]]

into a tensor of shape (2, 3) in which the number_of_tokens dimension is replaced by the position of the 1:

[[1, 0, 3],
 [0, 0, 2]]

I'm doing this to prepare the model output for comparison with the reference answer when computing the loss. I hope this is the correct approach.


Solution

  • If your original tensor is as specified in your previous question, you can bypass the one-hot encoding and directly use the argmax:

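    import torch

    t = torch.rand(2, 3, 4)  # scores of shape [batch_size, sequence_length, number_of_tokens]
    t = t.argmax(dim=2)      # index of the largest value at each position, shape [2, 3]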
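argmax also works on the one-hot tensor itself, because the single 1 in each row is its maximum. The sketch below applies it to the example from the question. The second half addresses the loss question under an assumption the question does not state: that the loss is cross-entropy and the model produces raw scores (logits). In that setup the argmax is usually applied to the one-hot *reference*, not to the model output, because argmax returns integer indices and gradients cannot flow through it. The random `logits` tensor is a hypothetical stand-in for a real model's output.

```python
import torch
import torch.nn as nn

# One-hot example from the question: [batch_size=2, sequence_length=3, number_of_tokens=4]
one_hot = torch.tensor([
    [[0, 1, 0, 0],
     [1, 0, 0, 0],
     [0, 0, 0, 1]],
    [[1, 0, 0, 0],
     [1, 0, 0, 0],
     [0, 0, 1, 0]],
])

# argmax over the last dimension gives the position of the single 1 in each row
indices = one_hot.argmax(dim=-1)
print(indices.tolist())  # [[1, 0, 3], [0, 0, 2]]
print(indices.shape)     # torch.Size([2, 3])

# Assumption: cross-entropy loss on raw scores. Random logits stand in for model output.
torch.manual_seed(0)
logits = torch.randn(2, 3, 4, requires_grad=True)

# nn.CrossEntropyLoss expects the class dimension second, (batch, classes, seq_len),
# and integer class indices as the target, so the reference is converted, not the logits.
loss = nn.CrossEntropyLoss()(logits.permute(0, 2, 1), indices)
loss.backward()  # gradients reach logits; they would not through an argmax of logits
print(loss.item())
```

If the reference answers are already stored as indices, the argmax step can be skipped and they can be passed to the loss as the target directly.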