
How can I zero out duplicate values in each row of a PyTorch tensor?


I would like to write a function that achieves the behavior described in this question.

That is, I want to zero out duplicate values in each row of a matrix in PyTorch. For example, given a matrix

torch.Tensor([[1, 2, 3, 4, 3, 3, 4],
              [1, 6, 3, 5, 3, 5, 4]])

I would like to get

torch.Tensor([[1, 2, 3, 4, 0, 0, 0],
              [1, 6, 3, 5, 0, 0, 4]])

or

torch.Tensor([[1, 2, 3, 4, 0, 0, 0],
              [1, 6, 3, 5, 4, 0, 0]])

According to the linked question, torch.unique() alone is not sufficient. I want to know how to implement this function without a loop.
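For reference, here is what `torch.unique()` returns on the example matrix. This is a quick sketch, shown only to illustrate the limitation; the variable names are arbitrary.

```python
import torch

x = torch.tensor([[1, 2, 3, 4, 3, 3, 4],
                  [1, 6, 3, 5, 3, 5, 4]])

# Without `dim`, torch.unique flattens the tensor and returns a single
# 1-D result, so the row structure is lost.
flat = torch.unique(x)
print(flat)  # tensor([1, 2, 3, 4, 5, 6])

# With dim=1 it removes duplicate *columns* (pairs of values), not repeated
# values inside each row; only the repeated column (3, 3) is dropped.
cols = torch.unique(x, dim=1)
print(cols.shape)  # torch.Size([2, 6])

# Calling torch.unique on each row separately gives ragged results
# (4 and 5 unique values here), which cannot be stacked into one tensor
# and requires a Python loop over the rows.
counts = [torch.unique(row).numel() for row in x]
print(counts)  # [4, 5]
```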


Solution

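    import torch

    x = torch.tensor([
        [1, 2, 3, 4, 3, 3, 4],
        [1, 6, 3, 5, 3, 5, 4]
    ], dtype=torch.long)
    
    # sort each row so that equal values become adjacent;
    # stable=True keeps equal values in their original left-to-right order,
    # so the value that survives below is always the first occurrence
    # e.g., first row: [1, 2, 3, 3, 3, 4, 4]
    y, indices = x.sort(dim=-1, stable=True)
    
    # compare each element with its left neighbour: a difference of 0 marks
    # a repeat, so multiplying by (difference != 0) zeroes every repeat
    # e.g., first row: [1, 2, 3, 0, 0, 4, 0]
    y[:, 1:] *= ((y[:, 1:] - y[:, :-1]) != 0).long()
    
    # sorting the sort indices gives the inverse permutation: for each
    # original column, the position its value moved to in the sorted row
    indices = indices.sort(dim=-1)[1]
    
    # gather the values back into the original column order
    # e.g., first row: [1, 2, 3, 4, 0, 0, 0]
    result = torch.gather(y, 1, indices)
    
    print(result)  # => output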
    

    Output

    tensor([[1, 2, 3, 4, 0, 0, 0],
            [1, 6, 3, 5, 0, 0, 4]])