Tags: math, pytorch, linear-algebra, tensor, matrix-factorization

PyTorch: how to factor several matrices so that one factor stays the same?


I have a matrix A that I'd like to factor as A = BC, where C is a matrix that doesn't change across multiple samples of A. How do I find B using PyTorch? A is known.

A1 = B1 C
A2 = B2 C
A3 = B3 C
...
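Because every sample shares the same C, the equations can be stacked row-wise into a single factorization of one tall matrix (assuming all Ai have the same number of columns; N, the number of samples, is introduced here for illustration):

```latex
\begin{bmatrix} A_1 \\ A_2 \\ \vdots \\ A_N \end{bmatrix}
=
\begin{bmatrix} B_1 \\ B_2 \\ \vdots \\ B_N \end{bmatrix} C
```

So finding a shared C amounts to factoring the stacked matrix once; each Bi is the corresponding row block of the left factor.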

Solution

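  • You could define C as a tensor with requires_grad=True, add it to an optimizer (e.g. SGD), and then at every iteration minimize some distance (e.g. MSELoss) between Ai and Bi @ C. The example assumes the Bi are given; if they are unknown too, they can be added to the optimizer the same way. Here is a minimal example: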

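    import torch
    import torch.nn.functional as F

    # A, B: sequences of samples A_i (p x n) and B_i (p x m), assumed given
    C = torch.rand(m, n, requires_grad=True)  # shared factor to learn
    optimizer = torch.optim.SGD([C], lr=0.1)

    for epoch in range(100):  # several passes over the samples (count is arbitrary)
        for a, b in zip(A, B):
            loss = F.mse_loss(b @ C, a)  # distance between A_i and B_i @ C
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()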
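The example above treats each Bi as given, but the question asks for B as well. Below is a minimal sketch, assuming both the per-sample Bi and the shared C are unknown: all of them are made learnable and the reconstruction error is summed over every sample before each step. The toy data, the sizes N, p, k, n (k being the assumed inner dimension), the step count, and the use of Adam instead of SGD are illustrative assumptions, not part of the original answer.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy data: N samples A_i = B_i @ C_true that all share one C_true.
N, p, k, n = 5, 8, 3, 6  # assumed sizes: A_i is (p, n), B_i is (p, k), C is (k, n)
C_true = torch.randn(k, n)
A = [torch.randn(p, k) @ C_true for _ in range(N)]

# Both factors are unknown: one B_i per sample and a single shared C.
B = [torch.randn(p, k, requires_grad=True) for _ in range(N)]
C = torch.randn(k, n, requires_grad=True)
optimizer = torch.optim.Adam(B + [C], lr=0.01)

with torch.no_grad():
    initial_loss = sum(F.mse_loss(b @ C, a) for a, b in zip(A, B)).item()

for step in range(3000):
    # Summed error over all samples, so each step fits C to every A_i at once.
    loss = sum(F.mse_loss(b @ C, a) for a, b in zip(A, B))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

final_loss = loss.item()
print(f"initial loss: {initial_loss:.3f}, final loss: {final_loss:.2e}")
```

Note that the factorization is not unique: for any invertible k x k matrix M, Bi C = (Bi M)(M^-1 C), so the learned C generally differs from C_true even when the reconstruction error is close to zero. Check the fit by comparing Bi @ C against Ai, not C against C_true.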