python, numpy, pytorch, torch

How to implement numpy's unique with return_index=True in PyTorch?


numpy.unique has an option return_index=True, which returns the positions of the unique elements (the first occurrence, if a value appears several times).
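For reference, a minimal illustration of the numpy behaviour in question (the array `a` is made up for this example):

```python
import numpy as np

a = np.array([3, 1, 3, 2, 1])

# values are the sorted distinct elements; first_idx[i] is the position
# in `a` where values[i] occurs for the first time.
values, first_idx = np.unique(a, return_index=True)

print(values.tolist())     # [1, 2, 3]
print(first_idx.tolist())  # [1, 3, 0]
```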

Unfortunately, torch.unique has no such option.

Question: What are fast, torch-style ways to get the indices of the unique elements?

=====================

More generally, my problem is the following: I have two vectors, v1 and v2, and I want the positions of those elements of v2 that are not in v1; for repeated elements I need only one position. Numpy's unique with return_index=True gives the solution immediately. How can this be done in torch? If v1 is known to be sorted, can that be used to speed up the process?
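To make the target concrete, here is a sketch of what the numpy version might look like. The sample vectors are illustrative, and combining `np.unique` with `np.isin` is one possible reading of "immediately gives the solution", not necessarily the exact code the question has in mind:

```python
import numpy as np

v1 = np.array([2, 3, 3])
v2 = np.array([1, 2, 6, 2, 3, 10, 4, 6, 4])

# First-occurrence position of every distinct value in v2.
values, first_idx = np.unique(v2, return_index=True)

# Keep only the distinct values that do not appear in v1.
keep = ~np.isin(values, v1)
positions = first_idx[keep]

print(values[keep].tolist())  # [1, 4, 6, 10]
print(positions.tolist())     # [0, 6, 2, 5]
```

Note that the positions come out ordered by value, not by position in v2.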


Solution

  • One way to do this in PyTorch is to sort the tensor and take the original index of the first element in each run of equal values:

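    import torch

    def get_unique_elements_first_idx(tensor):
        # Stable sort, so equal values keep their original relative order
        sorted_tensor, indices = torch.sort(tensor, stable=True)
        # An element starts a new run if it differs from its predecessor
        unique_mask = torch.ones_like(sorted_tensor, dtype=torch.bool)
        unique_mask[1:] = sorted_tensor[1:] != sorted_tensor[:-1]
        # Original positions of the first element of each run
        return indices[unique_mask]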
    

    Example usage:

    v1 = torch.tensor([2, 3, 3])
    v2 = torch.tensor([1, 2, 6, 2, 3, 10, 4, 6, 4])
    
    # Mask to find elements in v2 that are not in v1
    mask = ~torch.isin(v2, v1)
    v2_without_v1 = v2[mask]
    
    # Get unique elements and their first indices
    unique_indices = get_unique_elements_first_idx(v2_without_v1)
    
    print(unique_indices)           # [0, 3, 1, 2] (indices into v2_without_v1, not v2)
    print(v2[mask][unique_indices]) # [1, 4, 6, 10]
    

    P.S. On my computer, the function processes a vector of size 10 million in about (1.1 ± 0.1)s.
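
The example above yields indices into the filtered tensor, and it does not use the fact that v1 may already be sorted. The following is a rough sketch of how both points might be handled: `torch.searchsorted` replaces `torch.isin` for the membership test, and the result is mapped back to positions in the original v2. The function name `first_positions_not_in` is hypothetical, the code assumes 1-D tensors with v1 sorted in ascending order, and it has not been benchmarked against the version above.

```python
import torch

def first_positions_not_in(v2, v1_sorted):
    """Positions in v2 of the first occurrence of each value absent from v1_sorted.

    Assumes both inputs are 1-D and v1_sorted is in ascending order.
    """
    if v1_sorted.numel() == 0:
        # Nothing to exclude.
        mask = torch.ones_like(v2, dtype=torch.bool)
    else:
        # Binary search: index of the first element of v1_sorted that is >= each v2 value.
        pos = torch.searchsorted(v1_sorted, v2)
        # Values larger than everything in v1_sorted get index len(v1_sorted); clamp it.
        pos = pos.clamp(max=v1_sorted.numel() - 1)
        # A v2 value is in v1_sorted exactly when the element found equals it.
        mask = v1_sorted[pos] != v2

    # Positions in the original v2 that survive the filter.
    kept_positions = torch.nonzero(mask).squeeze(1)

    # Stable sort keeps equal values in original order, so the first
    # element of each run of equal values is the first occurrence.
    sorted_vals, order = torch.sort(v2[mask], stable=True)
    is_first = torch.ones_like(sorted_vals, dtype=torch.bool)
    is_first[1:] = sorted_vals[1:] != sorted_vals[:-1]

    return kept_positions[order[is_first]]

v1 = torch.tensor([2, 3, 3])  # already sorted
v2 = torch.tensor([1, 2, 6, 2, 3, 10, 4, 6, 4])

positions = first_positions_not_in(v2, v1)
print(positions.tolist())      # [0, 6, 2, 5]
print(v2[positions].tolist())  # [1, 4, 6, 10]
```

Whether the binary search is actually faster than `torch.isin` probably depends on the sizes of the two vectors and on the device, so it is worth timing both on real data.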