For example, given this PyTorch matrix A:
import torch
A = torch.tensor([[3,2,1],[1,0,2],[2,2,0]])
I need to replace every 0 on the diagonal with 1, so the result should be:
tensor([[3,2,1],[1,1,2],[2,2,1]])
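You can use PyTorch's built-in diagonal functions to replace the diagonal elements like so:
mask = A.diagonal() == 0  # True where a diagonal entry is 0
A += torch.diag(mask.to(A.dtype))  # adds 1 to exactly those entries, 0 elsewhere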
>>> A
tensor([[3, 2, 1],
        [1, 1, 2],
        [2, 2, 1]])
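As an alternative sketch (not part of the original answer): Tensor.diagonal() returns a view that shares storage with A, so you can also assign the replacement value directly through that view instead of building a full diagonal matrix. The variable name diag below is only for illustration.

```python
import torch

A = torch.tensor([[3, 2, 1], [1, 0, 2], [2, 2, 0]])

# diagonal() returns a view of A, so writing into `diag` modifies A in place
diag = A.diagonal()
diag[diag == 0] = 1  # set only the zero diagonal entries to 1

print(A)
```

Writing through the view avoids allocating an n x n matrix, which may matter for large tensors.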
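To replace the diagonal zeros with a different value, add mask * replace_value instead of mask. The masked diagonal entries are 0, so each one becomes exactly replace_value, while every other entry has 0 added to it.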
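A minimal sketch wrapping the mask approach in a reusable, out-of-place helper. The function name replace_diagonal_zeros is hypothetical, and the sketch assumes A is a square 2-D tensor (torch.diag builds a square matrix from the 1-D mask).

```python
import torch

def replace_diagonal_zeros(A: torch.Tensor, value) -> torch.Tensor:
    """Return a copy of square matrix A with its zero diagonal entries set to value.

    Hypothetical helper; A itself is left unchanged.
    """
    mask = A.diagonal() == 0  # True where the diagonal entry is 0
    # Zero entries plus value give exactly value; all other entries get + 0
    return A + torch.diag(mask.to(A.dtype) * value)

A = torch.tensor([[3, 2, 1], [1, 0, 2], [2, 2, 0]])
B = replace_diagonal_zeros(A, 7)
print(B)
```

Returning a new tensor instead of using += keeps the input intact, which can be useful when A is needed again later or is part of an autograd graph.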