I want to implement a multi-hot vector in PyTorch.
The following code succinctly demonstrates the desired behavior:
max_num = 4
multi_hot_num = 3
x = torch.tensor([0, 2, 1])
Expected output (shape `(3, max_num * multi_hot_num)` = `(3, 12)`):
tensor([[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0]])
So my question is: how can I create the expected output from x, max_num, and multi_hot_num? I tried to find an optimized way without a loop, but as far as I can tell, a loop seems to be the only solution.
import torch
max_num = 4
multi_hot_num = 3
x = torch.tensor([0, 2, 1])
assert torch.all(x < max_num)
result = torch.zeros(len(x), max_num * multi_hot_num)
for i, (start, end) in enumerate(zip(x * multi_hot_num, (x + 1) * multi_hot_num)):
    result[i, start:end] = 1.
print(result)
Result:
tensor([[1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0.],
[0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0.]])
Solution based on this answer
import torch
max_num = 4
multi_hot_num = 3
x = torch.tensor([0, 2, 1])
assert torch.all(x < max_num)
# Column j belongs to row i's hot block iff start[i] <= j < end[i]
start = (x * multi_hot_num).unsqueeze(-1)
end = ((x + 1) * multi_hot_num).unsqueeze(-1)
result = torch.zeros(len(x), max_num * multi_hot_num)
# One row of column indices per element of x
index = torch.arange(max_num * multi_hot_num).unsqueeze(0).repeat(len(x), 1)
gte_start = start <= index
lt_end = index < end
mask = gte_start & lt_end
result[mask] = 1
print(result)
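As a side note, another loop-free sketch (not from the linked answer, just an alternative I believe works here) builds the same pattern by one-hot encoding x and then stretching each column into a block of multi_hot_num columns with repeat_interleave:

```python
import torch
import torch.nn.functional as F

max_num = 4
multi_hot_num = 3
x = torch.tensor([0, 2, 1])
assert torch.all(x < max_num)

# One-hot encode each entry of x (shape (3, max_num)), then repeat every
# column `multi_hot_num` times along dim 1, yielding the multi-hot layout.
result = F.one_hot(x, num_classes=max_num).repeat_interleave(multi_hot_num, dim=1).float()
print(result)
```

This avoids building the intermediate index/mask tensors, at the cost of materializing the one-hot matrix first.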