Tags: arrays, numpy, pytorch, tensor, isin

PyTorch: Test whether each row of the first 2D tensor also exists in the second tensor?


Given two tensors t1 and t2:

t1=torch.tensor([[1,2],[3,4],[5,6]])
t2=torch.tensor([[1,2],[5,6]])

If a row of t1 also exists in t2, return True; otherwise return False. The desired result is [True, False, True]. I tried torch.isin(t1, t2), but it returns results element by element, not row by row. By the way, if they were NumPy arrays, this could be done with

np.in1d(t1.view('i,i').reshape(-1), t2.view('i,i').reshape(-1))
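For reference, here is a runnable sketch of that NumPy approach. It assumes the arrays are created with an explicit int32 dtype, because the 'i,i' record (two 32-bit ints) only lines up with one row when each element is 4 bytes wide; with a default int64 array the view would split every element in two instead. It also swaps in np.isin, the current spelling of np.in1d. The names a1, a2, rows1, rows2 and mask are illustrative only.

```python
import numpy as np

# Assumption: explicit int32, so one 'i,i' record (2 x 4 bytes) covers exactly one row.
a1 = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.int32)
a2 = np.array([[1, 2], [5, 6]], dtype=np.int32)

# Reinterpret each row as a single structured record, then flatten to 1-D.
rows1 = a1.view('i,i').reshape(-1)
rows2 = a2.view('i,i').reshape(-1)

# Membership test on whole records, i.e. on whole rows.
mask = np.isin(rows1, rows2)
print(mask)
```

The result has one boolean per row of a1, which is the row-wise behaviour wanted from the tensor version.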

How can I get a similar result with tensors?
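To make the problem with torch.isin concrete, the small sketch below shows its element-wise output for t1 and t2, and why simply reducing that output per row is not a valid row test. The tensor probe and the variable names are hypothetical, added only for illustration.

```python
import torch

t1 = torch.tensor([[1, 2], [3, 4], [5, 6]])
t2 = torch.tensor([[1, 2], [5, 6]])

# torch.isin checks every element of t1 against all values found in t2,
# so the result has the shape of t1, not one entry per row.
elementwise = torch.isin(t1, t2)
print(elementwise)

# Reducing with all(dim=1) is not a row test: the hypothetical row [1, 6]
# is not a row of t2, yet each of its elements occurs somewhere in t2.
probe = torch.tensor([[1, 6]])
false_positive = torch.isin(probe, t2).all(dim=1)
print(false_positive)
```

The second print reports a match for [1, 6] even though t2 has no such row, which is why a genuine row-by-row comparison is needed.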


Solution

import torch

def rowwise_in(a, b):
    """
    a - tensor of size a0,c
    b - tensor of size b0,c
    returns - boolean tensor of size a0, True for each row of a that is in b, False otherwise
    """

    # dimensions
    a0 = a.shape[0]
    b0 = b.shape[0]
    c = a.shape[1]
    assert c == b.shape[1], "Tensors must have same number of columns"

    # pair every row of a with every row of b: both views have shape (a0, b0, c)
    a_expand = a.unsqueeze(1).expand(a0, b0, c)
    b_expand = b.unsqueeze(0).expand(a0, b0, c)

    # element-wise equality
    equal = a_expand == b_expand

    # reduce along dim 2: a pair of rows matches only if all c elements are equal
    row_equal = equal.all(dim=2)

    # reduce along dim 1: a row of a is in b if it matches at least one row of b
    row_in_b = row_equal.any(dim=1)
    return row_in_b
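
As a usage sketch, the same idea can be written more compactly by relying on broadcasting instead of explicit expand calls. The function name rowwise_in_compact is hypothetical and not part of the answer above; it is assumed to return a boolean tensor with one entry per row of a.

```python
import torch

def rowwise_in_compact(a, b):
    """Hypothetical compact variant: a is (a0, c), b is (b0, c), result is bool of shape (a0,)."""
    # Broadcasting (a0, 1, c) against (1, b0, c) compares every row of a with every row of b.
    pair_equal = a.unsqueeze(1) == b.unsqueeze(0)
    # A pair matches if all c elements agree; a row of a is in b if any pair matches.
    return pair_equal.all(dim=2).any(dim=1)

t1 = torch.tensor([[1, 2], [3, 4], [5, 6]])
t2 = torch.tensor([[1, 2], [5, 6]])

result = rowwise_in_compact(t1, t2)
print(result)  # tensor([ True, False,  True])
```

Note that the intermediate comparison holds a0 * b0 * c booleans, so for very large inputs it may be worth processing a in chunks.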