python pytorch diagonal

Replace diagonal elements with vector in PyTorch


I have been searching everywhere for an equivalent of the following in PyTorch, but I cannot find anything:

    import numpy as np

    L_1 = np.tril(np.random.normal(scale=1., size=(D, D)), k=0)
    L_1[np.diag_indices_from(L_1)] = np.exp(np.diagonal(L_1))

I guess there is no way to replace the diagonal elements this elegantly in PyTorch.


Solution

  • I do not think such functionality is implemented as of now. But you can implement the same behavior with a mask, as follows.

    import torch

    # Assuming v is the vector and a is the square tensor whose diagonal is to be replaced
    mask = torch.diag(torch.ones_like(v))  # identity matrix: 1 on the diagonal, 0 elsewhere
    out = mask*torch.diag(v) + (1. - mask)*a
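
The same mask idea can also be written with `torch.where`, which selects elements instead of multiplying by 0 and 1. This is a minimal sketch, assuming `a` is a square 2-D float tensor and `v` is a 1-D tensor of matching length; `replace_diagonal` is a hypothetical helper name, not a PyTorch function. One practical difference: with the arithmetic mask, a `nan` or `inf` on the old diagonal of `a` still leaks into the result (since `0 * inf` is `nan`), whereas `torch.where` simply ignores the discarded values.

```python
import torch


def replace_diagonal(a: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Return a copy of square matrix `a` whose main diagonal is replaced by `v`.

    Hypothetical helper for illustration; it is out-of-place, so it also works
    when `a` requires grad.
    """
    # Boolean identity mask: True on the main diagonal, False elsewhere.
    mask = torch.eye(a.shape[0], device=a.device).bool()
    # torch.diag(v) builds a matrix with v on the diagonal and zeros elsewhere;
    # torch.where takes it on the diagonal and keeps `a` everywhere else.
    return torch.where(mask, torch.diag(v), a)


a = torch.tensor([[1.0, 2.0], [3.0, float("inf")]])
v = torch.tensor([10.0, 20.0])
print(replace_diagonal(a, v))  # inf on the old diagonal is simply dropped

# The arithmetic mask, for comparison: 0 * inf produces nan at position [1, 1].
mask_f = torch.diag(torch.ones_like(v))
print(mask_f * torch.diag(v) + (1. - mask_f) * a)
```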
    

    So your implementation would look something like this:

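    import torch

    D = 5  # example matrix size
    L_1 = torch.tril(torch.randn((D, D)))   # random lower-triangular matrix
    v = torch.exp(torch.diag(L_1))          # exponentiated diagonal
    mask = torch.diag(torch.ones_like(v))   # identity mask
    L_1 = mask*torch.diag(v) + (1. - mask)*L_1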
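
To check that the mask version reproduces the NumPy snippet from the question, one can run both on the same random lower-triangular matrix. This is a sketch; the size `D = 5` and the seed are arbitrary choices, and the starting matrix is generated once with NumPy so both versions see identical values.

```python
import numpy as np
import torch

D = 5  # arbitrary example size
rng = np.random.default_rng(0)  # arbitrary seed
base = np.tril(rng.normal(scale=1.0, size=(D, D)), k=0)

# NumPy version from the question.
L_np = base.copy()
L_np[np.diag_indices_from(L_np)] = np.exp(np.diagonal(L_np))

# Mask version from the answer, on the same starting values (float64).
L_1 = torch.from_numpy(base.copy())
v = torch.exp(torch.diag(L_1))
mask = torch.diag(torch.ones_like(v))
L_1 = mask * torch.diag(v) + (1. - mask) * L_1

print(np.allclose(L_np, L_1.numpy()))
```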
    

    Not as elegant as NumPy, but not too bad either.
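
For completeness, newer PyTorch releases offer closer analogues of the NumPy idiom. `Tensor.diagonal()` returns a view that shares storage with the matrix, so an in-place op on it writes straight back; indexing with `torch.arange` mirrors `np.diag_indices_from`; and `torch.diagonal_scatter` (available in recent versions, to my knowledge from 1.11 on) returns a new tensor instead. This is a sketch under those assumptions; the size and seed are arbitrary. One design consideration: the two in-place variants raise an error on a leaf tensor that has `requires_grad=True`, whereas the out-of-place options (the mask above, or `diagonal_scatter`) do not.

```python
import torch

D = 4  # arbitrary example size
torch.manual_seed(0)  # arbitrary seed
base = torch.tril(torch.randn(D, D))

# 1) In-place through the diagonal view: diagonal() shares storage with L_a.
L_a = base.clone()
L_a.diagonal().exp_()

# 2) Advanced indexing, the closest analogue of np.diag_indices_from.
L_b = base.clone()
idx = torch.arange(D)
L_b[idx, idx] = torch.exp(L_b[idx, idx])

# 3) Out-of-place: returns a new tensor and leaves `base` untouched,
#    so it also works when `base` requires grad.
L_c = torch.diagonal_scatter(base, torch.exp(base.diagonal()))

print(torch.allclose(L_a, L_b), torch.allclose(L_a, L_c))
```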