python pytorch one-hot-encoding multiclass-classification

How to convert a one-hot vector to a label index and back in PyTorch?


How do I transform a vector of labels into a one-hot encoding and back in PyTorch?

The solution below was copied here from a forum discussion that had to be read in full to find it, because a quick web search did not turn up a simple answer.


Solution

  • From the PyTorch forums:

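    import torch
    import numpy as np

    num_classes = 10
    labels = torch.randint(0, num_classes, (10,))

    # labels --> one-hot (pass num_classes so the width does not depend on the sampled labels)
    one_hot = torch.nn.functional.one_hot(labels, num_classes=num_classes)
    # one-hot --> labels
    labels_again = torch.argmax(one_hot, dim=1)

    # the round trip should recover the original labels
    np.testing.assert_equal(labels.numpy(), labels_again.numpy())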
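
If the labels are not a flat vector (for example one label per token or per pixel), the same round trip still works as long as the class axis is treated as the last dimension. Here is a minimal sketch under that assumption. The helper names `to_one_hot` and `from_one_hot` are hypothetical and are not part of PyTorch; only `torch.nn.functional.one_hot` and `Tensor.argmax` are library calls.

```python
import torch
import torch.nn.functional as F


def to_one_hot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    # Hypothetical helper: F.one_hot expects an int64 tensor of class indices
    # and appends a new last dimension of size num_classes.
    return F.one_hot(labels.long(), num_classes=num_classes)


def from_one_hot(one_hot: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: the class axis is the last one, so argmax over
    # dim=-1 recovers the index for any number of leading dimensions.
    return one_hot.argmax(dim=-1)


labels = torch.tensor([[0, 2], [1, 3]])      # shape (2, 2)
one_hot = to_one_hot(labels, num_classes=5)  # shape (2, 2, 5)
recovered = from_one_hot(one_hot)            # shape (2, 2)

assert torch.equal(recovered, labels)
```

Note that `F.one_hot` returns an integer tensor, so if the one-hot targets go into a loss that expects floating-point input, a cast such as `one_hot.float()` is probably needed first.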