
How do I create a batch of diagonal matrices in PyTorch, with a different diagonal value in each batch?


I want to create a tensor like this:

 tensor([[[1, 0, 0], [0, 1, 0], [0, 0, 1]], [[2, 0, 0], [0, 2, 0], [0, 0, 2]]])

That is, given a torch tensor B of size (1, n), i.e. n scalars, I want to create a torch tensor A of size (n, 3, 3) such that A[i] is B[0, i] times the 3x3 identity matrix.
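Written element-wise (zero-based indices, with δ the Kronecker delta, and reading the i-th scalar of the (1, n) tensor B as B[0, i]), the requested tensor is:

```latex
A_{i,k,l} = B_{0,i}\,\delta_{kl}, \qquad i = 0,\dots,n-1,\quad k,l \in \{0,1,2\}
\quad\Longleftrightarrow\quad A_i = B_{0,i}\, I_3
```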

How can I do this without a for loop?


Solution

  • Use torch.einsum (Einstein summation notation):

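    import torch

    A = torch.eye(3)                   # 3x3 identity matrix
    b = torch.tensor([1.0, 2.0, 3.0])  # one scale factor per batch entry
    # out[k, i, j] = A[i, j] * b[k]: an outer product, no index is summed
    torch.einsum('ij,k->kij', A, b)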

    This returns:

    tensor([[[1., 0., 0.],
             [0., 1., 0.],
             [0., 0., 1.]],

            [[2., 0., 0.],
             [0., 2., 0.],
             [0., 0., 2.]],

            [[3., 0., 0.],
             [0., 3., 0.],
             [0., 0., 3.]]])
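The einsum subscripts 'ij,k->kij' mean out[k, i, j] = A[i, j] * b[k]; since no index is summed, it is just an outer product. The same tensor can also be built with plain broadcasting against torch.eye, or with torch.diag_embed. The sketch below wraps both in helper functions whose names (batched_scaled_identity, batched_scaled_identity_diag) are hypothetical, chosen here for illustration, and flattens the input first, assuming B may arrive with shape (1, n) as in the question.

```python
import torch


# Hypothetical helper names, chosen for illustration.
def batched_scaled_identity(b: torch.Tensor, size: int = 3) -> torch.Tensor:
    """Return an (n, size, size) tensor whose i-th slice is b[i] * I."""
    b = b.reshape(-1)  # accept shape (n,) or (1, n)
    eye = torch.eye(size, dtype=b.dtype, device=b.device)
    # Broadcasting: (n, 1, 1) * (size, size) -> (n, size, size)
    return b[:, None, None] * eye


def batched_scaled_identity_diag(b: torch.Tensor, size: int = 3) -> torch.Tensor:
    """Same result via torch.diag_embed."""
    b = b.reshape(-1)
    # Repeat each scalar along a new last dim -> (n, size),
    # then diag_embed puts each row on the diagonal of a (size, size) matrix.
    return torch.diag_embed(b.unsqueeze(-1).expand(-1, size))


B = torch.tensor([[1.0, 2.0, 3.0]])  # shape (1, n), as in the question
A1 = batched_scaled_identity(B)
A2 = batched_scaled_identity_diag(B)
A3 = torch.einsum('ij,k->kij', torch.eye(3), B.reshape(-1))
print(A1.shape)                                     # torch.Size([3, 3, 3])
print(torch.equal(A1, A2) and torch.equal(A1, A3))  # True
```

Broadcasting is arguably the most readable of the three. torch.diag_embed is handy if each matrix later needs different values along its diagonal: pass an (n, 3) tensor to it directly instead of expanding a single scalar per row.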