python pytorch tensor torch

How to apply a function element-wise to a 2D tensor


This is a very simple question, but I have been struggling with it for a long time now.

import torch
t = torch.tensor([[2,3],[4,6]])
overlap = [2, 6]
f = lambda x: x in overlap

I want:

torch.tensor([[True,False],[False,True]])

Both the tensor and overlap are very large, so the solution needs to be efficient.
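For reference, a minimal vectorized sketch of the desired result, assuming PyTorch 1.10 or newer (where torch.isin was added) and converting overlap to a tensor first. It avoids calling a Python function per element, which matters when both inputs are large:

```python
import torch

t = torch.tensor([[2, 3], [4, 6]])
# Convert the lookup list to a tensor with the same dtype as t.
overlap = torch.tensor([2, 6], dtype=t.dtype)

# torch.isin tests every element of t against overlap in one vectorized call;
# the result has t's shape and dtype torch.bool.
mask = torch.isin(t, overlap)  # [[True, False], [False, True]]
```
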


Solution

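  • I found an easy way. NumPy can convert a CPU tensor to an array, so np.vectorize can be applied to the tensor directly. Note that np.vectorize is essentially a Python for loop, so it is convenient rather than fast on very large inputs: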

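    import torch
    import numpy as np

    t = torch.tensor([[2, 3], [4, 6]])
    overlap = [2, 6]
    f = lambda x: x in overlap  # membership test for one element

    # np.vectorize converts the CPU tensor to an array and calls f per element
    mask = np.vectorize(f)(t)       # NumPy bool array
    mask = torch.from_numpy(mask)   # convert back to a torch bool tensor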

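Because np.vectorize calls the Python lambda once per element, a fully vectorized alternative may scale better. The sketch below is one such option, swapping in a sort-and-binary-search membership test built on torch.unique and torch.searchsorted; the helper name isin_sorted is hypothetical, and it assumes a PyTorch version that provides torch.searchsorted. It costs roughly O(n log m) for n tensor elements and m overlap values:

```python
import torch


def isin_sorted(t: torch.Tensor, overlap: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: element-wise membership of t in overlap via binary search."""
    # Sort and deduplicate the lookup values; match t's dtype for searchsorted.
    values = torch.unique(overlap.to(t.dtype))
    if values.numel() == 0:
        return torch.zeros_like(t, dtype=torch.bool)
    # For each element of t, find its insertion index in the sorted values.
    idx = torch.searchsorted(values, t.contiguous())
    # Elements larger than every value get index len(values); clamp into range.
    idx = idx.clamp(max=values.numel() - 1)
    # An element is a member iff the value at its insertion index equals it.
    return values[idx] == t


t = torch.tensor([[2, 3], [4, 6]])
overlap = torch.tensor([2, 6])
mask = isin_sorted(t, overlap)  # [[True, False], [False, True]]
```

On PyTorch 1.10 or newer, torch.isin(t, overlap) gives the same result in a single call.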