PyTorch multi-hot vectors


I want to implement a multi-hot vector in PyTorch.

  1. Create a zero tensor of size len(x) x (multi_hot_num * max_num).
  2. Then, for each element x[i], set the entries of row i in the half-open index range from x[i] * multi_hot_num (inclusive) to (x[i]+1) * multi_hot_num (exclusive) to 1.

The following inputs illustrate the desired behavior:

import torch

max_num = 4
multi_hot_num = 3
x = torch.tensor([0, 2, 1])

Expected output (3 rows, max_num * multi_hot_num = 12 columns):

tensor([[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0]])

So my question is: how can I create the expected output from x, max_num, and multi_hot_num?


Solution

  • I first tried to find an optimized way without a loop, but to the best of my knowledge at the time a loop was the only solution (a loop-free version is in the edit further down).

    import torch
    
    max_num = 4
    multi_hot_num = 3
    x = torch.tensor([0, 2, 1])
    
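    # Every index must select one of the max_num blocks.
    assert torch.all(x < max_num)
    
    result = torch.zeros(len(x), max_num*multi_hot_num)
    # For row i, set the block [x[i]*multi_hot_num, (x[i]+1)*multi_hot_num) to 1.
    for i, (start, end) in enumerate(zip(x*multi_hot_num, (x+1)*multi_hot_num)):
        result[i, start:end] = 1.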
        
    print(result)
    

    Result:

    tensor([[1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0.],
            [0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0.]])
    

    Edit: without a for loop

    Solution based on this answer

    import torch
    
    max_num = 4
    multi_hot_num = 3
    x = torch.tensor([0, 2, 1])
    
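    # Every index must select one of the max_num blocks.
    assert torch.all(x < max_num)
    
    # Per-row block boundaries as column vectors of shape (len(x), 1).
    start = (x * multi_hot_num).unsqueeze(-1)
    end = ((x+1) * multi_hot_num).unsqueeze(-1)
    
    result = torch.zeros(len(x), max_num*multi_hot_num)
    # Column indices 0..max_num*multi_hot_num-1, copied once per row.
    index = torch.arange(max_num*multi_hot_num).unsqueeze(0).repeat(len(x), 1)
    
    # A column belongs to a row's block when start <= column < end.
    gte_start = start <= index
    lt_end = index < end
    mask = gte_start & lt_end
    result[mask]=1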
    
    print(result)
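
Another loop-free option, not part of the original answer, is to one-hot encode x and then widen each column into a block. The sketch below assumes x holds integer (int64) indices in the range 0 to max_num - 1. It uses torch.nn.functional.one_hot followed by Tensor.repeat_interleave. The names one_hot and multi_hot are only illustrative.

```python
import torch
import torch.nn.functional as F

max_num = 4
multi_hot_num = 3
x = torch.tensor([0, 2, 1])

# One-hot encode each index over max_num classes: shape (len(x), max_num).
one_hot = F.one_hot(x, num_classes=max_num)

# Repeat every column multi_hot_num times consecutively, so each hot
# position becomes a block: shape (len(x), max_num * multi_hot_num).
multi_hot = one_hot.repeat_interleave(multi_hot_num, dim=1)

print(multi_hot)
```

This yields an integer (int64) tensor rather than the float tensor produced by torch.zeros above. Call .float() on the result if a floating-point tensor is needed.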