
Checking if a tensor's values are contained in another tensor

I have a torch tensor like so:

``````
a = [1, 234, 54, 6543, 55, 776]
``````

and other tensors like so:

``````
b = [234, 54]
c = [55, 776]
``````

I want to create a new mask tensor that is true at each position where the value of `a` is equal to a value in one of the other tensors (`b` or `c`).

For example, for the tensors above I would like to create the following mask tensor:

``````
a_masked = [False, True, True, False, True, True]
# The first two True values correspond to tensor `b`, while the last two True values
# correspond to tensor `c`.
``````

I have seen other methods that check whether a full tensor is contained in another, but that isn't the case here.

Is there a torch way to do this efficiently? Thanks!

Solution

• Based on the answers on the PyTorch forum, you could use an explicit loop, e.g.,

``````
import torch

a = torch.tensor([1, 234, 54, 6543, 55, 776])
b = torch.tensor([234, 54])
c = torch.tensor([55, 776])

# Compare `a` with each value of `b` and `c`, then combine the matches.
a_masked = sum(a == i for i in b).bool() + sum(a == i for i in c).bool()
``````

However, there is actually a PyTorch `isin` function, which lets you do:

``````
a_masked = torch.isin(a, torch.cat([b, c]))
``````

This is several times faster than the `sum` method.
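
If `torch.isin` is not available in your PyTorch version, the same mask can be built with broadcasting. The following is a minimal sketch rather than a definitive implementation; the names `candidates`, `mask_isin`, and `mask_broadcast` are illustrative and do not come from the original answer. It computes the mask both ways and checks that they agree:

```python
import torch

a = torch.tensor([1, 234, 54, 6543, 55, 776])
b = torch.tensor([234, 54])
c = torch.tensor([55, 776])

# Gather all values to look for into a single 1-D tensor.
candidates = torch.cat([b, c])

# Built-in membership test (needs a PyTorch version that provides torch.isin).
mask_isin = torch.isin(a, candidates)

# Broadcasting fallback: compare every element of `a` with every candidate,
# giving a (len(a), len(candidates)) boolean matrix, then reduce over candidates.
mask_broadcast = (a.unsqueeze(1) == candidates.unsqueeze(0)).any(dim=1)

print(mask_isin.tolist())                      # [False, True, True, False, True, True]
print(torch.equal(mask_isin, mask_broadcast))  # True
```

Note that the broadcasting version materializes a `len(a) x len(candidates)` comparison matrix, so its memory use grows with the product of the two sizes.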