Suppose I have a tensor t consisting only of zeros and ones:

t = torch.Tensor([1, 0, 0, 1])

And a dict with the weights:

weights = {0: 0.1, 1: 0.9}
I want to form a new tensor new_t, such that every element of t is mapped to the corresponding value in the dict weights:

new_t = torch.Tensor([0.9, 0.1, 0.1, 0.9])
Is there an elegant way to do this without iterating over tensor t? I've heard about torch.Tensor.apply_, but it only works when the tensor is on the CPU. Are there any other options?
If you convert your weights dict into a tensor, you can index into it directly:
t = torch.tensor([1, 0, 0, 1])
# value at index i is the weight for label i
weights = torch.tensor([0.1, 0.9])
# integer (advanced) indexing: new_t[i] = weights[t[i]]
new_t = weights[t]
new_t
>tensor([0.9000, 0.1000, 0.1000, 0.9000])
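If your keys aren't simply 0 and 1, you can still build the lookup tensor from the dict first. Here is a minimal sketch assuming the keys are non-negative integers; build_lookup is just an illustrative helper, not a PyTorch function:

import torch

# Hypothetical helper: turn an {int: float} dict into a lookup tensor,
# assuming the keys are non-negative integers.
def build_lookup(weights: dict, device=None) -> torch.Tensor:
    lookup = torch.zeros(max(weights) + 1, device=device)
    for k, v in weights.items():
        lookup[k] = v
    return lookup

weights = {0: 0.1, 1: 0.9}
t = torch.tensor([1, 0, 0, 1])

lookup = build_lookup(weights, device=t.device)
new_t = lookup[t]   # tensor([0.9000, 0.1000, 0.1000, 0.9000])

Since weights[t] is ordinary tensor indexing, it also works on the GPU: move both tensors there with .to('cuda') (or .cuda()) and the whole lookup stays on the device, so the CPU-only restriction of torch.Tensor.apply_ doesn't apply.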