
How to replace specific values in PyTorch tensor along diagonal?


For example, given a PyTorch tensor A:

A = tensor([[3,2,1],[1,0,2],[2,2,0]])

I need to replace the 0s on the diagonal with 1, so the result should be:

tensor([[3,2,1],[1,1,2],[2,2,1]])

Solution

  • You can use PyTorch's built-in diagonal functions, Tensor.diagonal and torch.diag, to replace the diagonal elements:

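    import torch

    A = torch.tensor([[3, 2, 1], [1, 0, 2], [2, 2, 0]])
    mask = A.diagonal() == 0   # True where a diagonal entry is 0
    A += torch.diag(mask)      # adds 1 (True) at exactly those diagonal positions
    
    >>> A
    tensor([[3, 2, 1],
            [1, 1, 2],
            [2, 2, 1]])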
    

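    To replace the 0s with a different value, use mask * replace_value instead of mask, i.e. A += torch.diag(mask * replace_value). Because the masked diagonal entries are 0, adding replace_value sets them to exactly that value.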
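The approach above can be wrapped in a small helper. This is a sketch only: the function names are hypothetical, and the second function swaps in a different technique (writing through the view returned by A.diagonal()) rather than the torch.diag trick from the answer.

```python
import torch


def replace_zero_diagonal(A: torch.Tensor, value) -> torch.Tensor:
    """Hypothetical helper: return a copy of square matrix A with
    zero diagonal entries set to `value` (the torch.diag approach)."""
    mask = A.diagonal() == 0  # True where a diagonal entry is 0
    # mask * value is `value` at masked positions and 0 elsewhere;
    # adding it to the zero entries sets them to exactly `value`.
    return A + torch.diag(mask * value).to(A.dtype)


def replace_zero_diagonal_inplace(A: torch.Tensor, value) -> None:
    """Hypothetical alternative: A.diagonal() returns a view of A,
    so assigning through a boolean mask on it modifies A in place."""
    d = A.diagonal()
    d[d == 0] = value


A = torch.tensor([[3, 2, 1], [1, 0, 2], [2, 2, 0]])
print(replace_zero_diagonal(A, 1))
replace_zero_diagonal_inplace(A, 1)
print(A)
```

The torch.diag form needs a square matrix, because torch.diag builds a square matrix whose size is the length of the diagonal. The view-based form also works on rectangular tensors, since A.diagonal() is defined for them.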