
Finding non-intersection of two pytorch tensors


Thanks in advance for your help! I'm trying to do something in PyTorch like NumPy's setdiff1d, i.e. keep the elements of one tensor that do not appear in another. For example, given the two tensors below:

t1 = torch.tensor([1, 9, 12, 5, 24]).to('cuda:0')
t2 = torch.tensor([1, 24]).to('cuda:0')

The expected output should be (sorted or unsorted):

torch.tensor([9, 12, 5])

Ideally, all operations should run on the GPU, with no data moving back and forth between GPU and CPU. Much appreciated!
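For reference, this is the NumPy behaviour being asked for, run on the CPU with plain arrays. Note that np.setdiff1d returns the unique values of the first array that are not in the second, and it returns them sorted. Since the question accepts either order, a GPU version does not need to reproduce the sorting.

```python
import numpy as np

t1 = np.array([1, 9, 12, 5, 24])
t2 = np.array([1, 24])

# setdiff1d returns the sorted, unique values of t1 that are not in t2
result = np.setdiff1d(t1, t2)
print(result)  # [ 5  9 12]
```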


Solution

  • If you don't want to leave CUDA, a workaround is to build a boolean mask, one element of t2 at a time:

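    import torch

    t1 = torch.tensor([1, 9, 12, 5, 24], device='cuda')
    t2 = torch.tensor([1, 24], device='cuda')
    # start with every element of t1 marked "keep" (bool mask; uint8 masks are deprecated for indexing)
    mask = torch.ones_like(t1, dtype=torch.bool)
    for elem in t2:
        # drop every element of t1 equal to the current element of t2
        mask &= (t1 != elem)
    difference = t1[mask]  # keeps 9, 12, 5 in their original order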
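The loop above issues one comparison per element of t2. A vectorized alternative, sketched below, is torch.isin with invert=True (available in PyTorch 1.10 and later), plus a broadcasting fallback for older versions. These are swapped-in techniques, not the answer's original approach, and the CPU fallback for the device is an assumption added only so the sketch runs without a GPU.

```python
import torch

# Assumption: fall back to CPU so the sketch also runs without CUDA
device = "cuda" if torch.cuda.is_available() else "cpu"

t1 = torch.tensor([1, 9, 12, 5, 24], device=device)
t2 = torch.tensor([1, 24], device=device)

# Option 1 (PyTorch >= 1.10): invert=True marks elements of t1 that are NOT in t2
mask_isin = torch.isin(t1, t2, invert=True)
diff_isin = t1[mask_isin]

# Option 2 (older versions): compare t1 against t2 as a (len(t1), len(t2)) matrix
# via broadcasting, and keep t1 elements that differ from every element of t2
mask_bcast = (t1.unsqueeze(1) != t2.unsqueeze(0)).all(dim=1)
diff_bcast = t1[mask_bcast]

print(diff_isin)
print(diff_bcast)
```

Both options keep the surviving elements in their original order and, unlike np.setdiff1d, do not remove duplicates within t1. The broadcasting version allocates a len(t1) × len(t2) boolean matrix, so it suits small t2. In all variants the final boolean-mask indexing has to find out how many elements survive, which requires a synchronization with the GPU. The data itself stays on the device.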