Tags: python, dictionary, pytorch, mapping, tensor

Map each element of a torch.Tensor to its value in a dict


Suppose I have a tensor t consisting only of zeros and ones:

t = torch.Tensor([1, 0, 0, 1])

And a dict with the weights:

weights = {0: 0.1, 1: 0.9}

I want to form a new tensor new_t in which every element of t is replaced by the corresponding value in the dict weights:

new_t = torch.Tensor([0.9, 0.1, 0.1, 0.9])

Is there an elegant way to do this without iterating over t? I've heard about Tensor.apply_, but it only works on CPU tensors. Are there any other options?


Solution

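  • If you convert your weights dict into a tensor in which position k holds the weight for key k, you can index it directly with t. The index tensor must have an integer dtype: torch.tensor([1, 0, 0, 1]) gives int64, while torch.Tensor([1, 0, 0, 1]) gives float32, which cannot be used as an index (convert it with t.long() first). The lookup is a single vectorized operation, so it also runs on the GPU as long as weights is on the same device as t: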

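    import torch

    t = torch.tensor([1, 0, 0, 1])        # integer (int64) index tensor
    weights = torch.tensor([0.1, 0.9])    # weights[k] is the value for key k

    new_t = weights[t]                    # advanced indexing: one lookup per element
    print(new_t)
    # tensor([0.9000, 0.1000, 0.1000, 0.9000])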
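Direct indexing relies on the dict keys being exactly 0, 1, ..., n-1. For arbitrary integer keys, one possible generalization (a sketch, not part of the original answer) is to sort the keys and use torch.searchsorted to find each element's position among them. The name map_with_dict is a hypothetical helper. It assumes every value in t is a key of the dict; values that are not keys are not detected and may silently pick up another key's weight or raise an indexing error. All helper tensors are created on t.device, so the same code runs on the GPU.

```python
import torch


def map_with_dict(t: torch.Tensor, mapping: dict) -> torch.Tensor:
    """Hypothetical helper: replace every element of t with mapping[element].

    Assumes every value occurring in t is a key of `mapping` (not checked).
    """
    # Sort the keys so that torch.searchsorted can locate each element.
    keys = sorted(mapping)
    # Keys share t's dtype and device (searchsorted expects matching dtypes).
    key_t = torch.tensor(keys, dtype=t.dtype, device=t.device)
    val_t = torch.tensor([mapping[k] for k in keys], device=t.device)
    # For each element of t, its index in the sorted key tensor.
    idx = torch.searchsorted(key_t, t)
    # Gather the corresponding values in one vectorized lookup.
    return val_t[idx]
```

With the question's inputs, map_with_dict(torch.Tensor([1, 0, 0, 1]), {0: 0.1, 1: 0.9}) returns the values 0.9, 0.1, 0.1, 0.9, even though t is a float tensor here, because the keys are cast to t's dtype. With non-contiguous keys, map_with_dict(torch.tensor([10, -1, 3, 3]), {-1: 0.5, 3: 2.0, 10: 7.0}) returns the values 7.0, 0.5, 2.0, 2.0. When the keys already are 0..n-1, plain indexing as in the answer is simpler and avoids the binary search.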