
Constrained linear combination of learned parameters in PyTorch?


I have three tensors X, Y, Z and I want to learn the optimal convex combination of these tensors with respect to some cost, i.e.

aX + bY + cZ such that a + b + c = 1 (and a, b, c ≥ 0). How can I do this easily in PyTorch?

I know that I could just concatenate along a new unsqueezed axis and then apply a linear layer, like so:

    X = X.unsqueeze(-1)
    Y = Y.unsqueeze(-1)
    Z = Z.unsqueeze(-1)
    W = torch.cat([X, Y, Z], dim=-1)  # last axis now has dimension 3
    W = nn.Linear(3, 1)(W)

but this would not enforce the convex combination constraint...


Solution

  • I found an answer that works well, for those who are interested. This generalizes to a linear combination of N tensors: just change the size of the weights and the number of tensors you concatenate.

    weights = nn.Parameter(torch.rand(1, 3))
    X = X.unsqueeze(-1)
    Y = Y.unsqueeze(-1)
    Z = Z.unsqueeze(-1)
    # Softmax makes the coefficients positive and sum to 1,
    # so the output is a true convex combination of X, Y, Z.
    weights_normalized = nn.functional.softmax(weights, dim=-1)
    output = torch.matmul(torch.cat([X, Y, Z], dim=-1), weights_normalized.t()).squeeze()
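The snippet above can be wrapped into a small module and trained end to end. Here is a minimal sketch, assuming the three tensors share a shape and the cost is mean-squared error against some target; the `ConvexCombination` name and the toy 0.7/0.2/0.1 target mixture are illustrative, not part of the original answer:

```python
import torch
import torch.nn as nn

class ConvexCombination(nn.Module):
    """Learn a convex combination of n same-shaped tensors."""

    def __init__(self, n):
        super().__init__()
        # Unconstrained logits; softmax in forward() maps them onto the simplex.
        self.weights = nn.Parameter(torch.rand(1, n))

    def forward(self, *tensors):
        # Stack along a new trailing axis: shape (..., n)
        stacked = torch.stack(tensors, dim=-1)
        # Softmax makes the coefficients positive and sum to 1.
        w = nn.functional.softmax(self.weights, dim=-1)
        # (..., n) @ (n, 1) -> (..., 1), then drop the trailing axis.
        return torch.matmul(stacked, w.t()).squeeze(-1)

# Toy usage: recover a known mixture 0.7*X + 0.2*Y + 0.1*Z.
torch.manual_seed(0)
X, Y, Z = torch.randn(3, 64, 8).unbind(0)
target = 0.7 * X + 0.2 * Y + 0.1 * Z

model = ConvexCombination(3)
opt = torch.optim.Adam(model.parameters(), lr=0.1)
for _ in range(200):
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(X, Y, Z), target)
    loss.backward()
    opt.step()

# Recovered coefficients: positive and summing to 1 by construction.
learned = nn.functional.softmax(model.weights, dim=-1).detach()
```

Because the constraint is baked into the parameterization via softmax, any gradient-based optimizer can be used without projection steps or Lagrange multipliers.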