
In Torch how do I create a 1-hot tensor from a list of integer labels?


I have a byte tensor of integer class labels, e.g. from the MNIST data set.

 1
 7
 5
[torch.ByteTensor of size 3]

How do I use it to create a tensor of one-hot vectors?

 1  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  1  0  0  0
 0  0  0  0  1  0  0  0  0  0
[torch.DoubleTensor of size 3x10]

I know I could do this with a loop, but I'm wondering if there's any clever Torch indexing that will get it for me in a single line.


Solution

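    require 'torch'
    labels = torch.ByteTensor{1, 7, 5}           -- class labels from the question
    indices = labels:long():view(-1, 1)          -- scatter needs a LongTensor index; reshape to 3x1
    one_hot = torch.zeros(labels:size(1), 10)    -- 3x10 tensor of zeros
    one_hot:scatter(2, indices, 1)               -- in place: one_hot[i][indices[i][1]] = 1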

You can find the documentation for scatter in the torch/torch7 GitHub repository (master branch).
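For a 2-D tensor and dim = 2, scatter writes self[i][index[i][j]] = val for every entry of the index, which is exactly "put a 1 in column label[i] of row i". To make that rule concrete, here is a minimal plain-Lua sketch that needs no Torch: tensors are modelled as nested 1-based tables, and scatter_dim2 and zeros are hypothetical helpers written only for illustration, not Torch functions.

```lua
-- Hypothetical helper: plain-Lua model of Tensor:scatter(2, index, value)
-- for 2-D data stored as nested 1-based tables (not a Torch function).
local function scatter_dim2(out, index, value)
  for i = 1, #index do              -- each row of the index
    for j = 1, #index[i] do         -- each column of that row
      out[i][index[i][j]] = value   -- out[i][index[i][j]] = value
    end
  end
  return out                        -- modified in place, like scatter
end

-- Hypothetical helper: an n x k table of zeros, standing in for torch.zeros(n, k).
local function zeros(n, k)
  local t = {}
  for i = 1, n do
    t[i] = {}
    for c = 1, k do t[i][c] = 0 end
  end
  return t
end

local labels = {1, 7, 5}
local index = {}
for i, label in ipairs(labels) do
  index[i] = {label}                -- one label per row, like :view(-1, 1)
end

local one_hot = scatter_dim2(zeros(#labels, 10), index, 1)

for _, row in ipairs(one_hot) do
  print(table.concat(row, " "))
end
```

Running it prints the same 3x10 pattern shown in the question. Because Torch7 indices are 1-based, labels that run from 0 to 9 (raw digit values) would need 1 added to them before being used as the scatter index.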