python pytorch

Checking if a tensor's values are contained in another tensor


I have a torch tensor like so:

a=[1, 234, 54, 6543, 55, 776]

and other tensors like so:

b=[234, 54]
c=[55, 776]

I want to create a boolean mask for a that is True wherever a value of a also appears in one of the other tensors (b or c).

For example, in the tensors we have above I would like to create the following masking tensor:

a_masked = [False, True, True, False, True, True]
# The first two True values correspond to tensor `b`, while the last two True values
# correspond to tensor `c`.

I have seen other methods to check whether a full tensor is contained in another but this isn't the case here.

Is there a torch way to do this efficiently? Thanks!


Solution

  • Based on the answers on the PyTorch forum here, you could use an explicit loop, e.g.,

    import torch
    
    a = torch.tensor([1, 234, 54, 6543, 55, 776])
    b = torch.tensor([234, 54])
    c = torch.tensor([55, 776])
    
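    # For each value in b (and in c), compare it against all of a and add up the matches.
    # .bool() turns the counts into a mask; + on boolean tensors acts as a logical OR.
    a_masked = sum(a == i for i in b).bool() + sum(a == i for i in c).bool()
    
    print(a_masked)
    # tensor([False,  True,  True, False,  True,  True])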
    

    However, PyTorch also provides a built-in torch.isin function, so you could simply write:

    a_masked = torch.isin(a, torch.cat([b, c]))
    

    This is several times faster than the sum method, which has to run a Python-level loop over every element of b and c.
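
If torch.isin is not available in your PyTorch version (I believe it was added in 1.10, but check the release notes for your install), a broadcast comparison should give the same mask without a Python loop. The snippet below is only a minimal sketch of that idea: the names values, mask_broadcast, mask_isin and mask_not_in are made up for illustration, and it also shows the invert keyword of torch.isin for the opposite mask.

```python
import torch

a = torch.tensor([1, 234, 54, 6543, 55, 776])
b = torch.tensor([234, 54])
c = torch.tensor([55, 776])

# All values to test membership against, as one 1-D tensor.
values = torch.cat([b, c])

# a[:, None] has shape (6, 1) and values has shape (4,), so the
# comparison broadcasts to a (6, 4) boolean matrix; any(dim=1)
# is True for each element of a that matches at least one value.
mask_broadcast = (a[:, None] == values).any(dim=1)

# The same mask from the built-in function.
mask_isin = torch.isin(a, values)

# invert=True flips the result: True where the value is NOT in values.
mask_not_in = torch.isin(a, values, invert=True)

print(mask_broadcast.tolist())
print(mask_not_in.tolist())
```

Note that the broadcast version materialises a len(a) x len(values) intermediate, so for large tensors torch.isin is likely the better choice when you have it.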