I have three tensors X, Y, Z and I want to learn the optimal convex combination of these tensors with respect to some cost, i.e.
aX + bY + cZ such that a + b + c = 1 (and a, b, c ≥ 0). How can I do this easily in PyTorch?
I know that I could just concatenate along an unsqueezed axis and then apply a linear layer, like so:
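import torch

X = X.unsqueeze(-1)
Y = Y.unsqueeze(-1)
Z = Z.unsqueeze(-1)
W = torch.cat([X, Y, Z], dim=-1)  # new last axis has size 3
linear = torch.nn.Linear(3, 1, bias=False)  # learns unconstrained weights a, b, c
W = linear(W).squeeze(-1)  # weighted sum over the last axis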
But this does not enforce the convex combination constraint: the learned weights can be negative and need not sum to 1.
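As a side note on the concatenation step, here is a small sketch (assuming X, Y, Z share the same shape; the sizes below are arbitrary) showing that `torch.stack` along the last axis produces the same tensor as unsqueezing each input and then concatenating:

```python
import torch

# Three example tensors of the same shape (arbitrary sizes for illustration).
X = torch.randn(4, 5)
Y = torch.randn(4, 5)
Z = torch.randn(4, 5)

# Unsqueeze-then-concatenate, as in the question.
W_cat = torch.cat([X.unsqueeze(-1), Y.unsqueeze(-1), Z.unsqueeze(-1)], dim=-1)

# torch.stack inserts the new axis and concatenates in a single call.
W_stack = torch.stack([X, Y, Z], dim=-1)

print(W_stack.shape)  # torch.Size([4, 5, 3])
print(torch.equal(W_cat, W_stack))  # True
```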
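I found an answer that works well, for those who are interested: learn unconstrained weights and pass them through a softmax, which makes them positive and makes them sum to 1. This generalizes to a convex combination of N tensors; you just need to change the size of the weights and the number of tensors you concatenate.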
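import torch
import torch.nn as nn

weights = nn.Parameter(torch.rand(1, 3))  # one unconstrained weight per tensor
X = X.unsqueeze(-1)
Y = Y.unsqueeze(-1)
Z = Z.unsqueeze(-1)
weights_normalized = nn.functional.softmax(weights, dim=-1)  # positive, sums to 1
# (..., 3) @ (3, 1) -> (..., 1); squeeze only the last axis so size-1 batch dims survive
output = torch.matmul(torch.cat([X, Y, Z], dim=-1), weights_normalized.t()).squeeze(-1)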
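A minimal sketch of the same softmax idea wrapped in a module for N tensors, with a toy training loop. The class name `ConvexCombination`, the squared-error cost, and the target built from weights 0.1/0.2/0.7 are hypothetical, chosen only for illustration. It uses `torch.stack` plus an elementwise multiply-and-sum instead of `cat` plus `matmul`, which computes the same weighted sum.

```python
import torch
import torch.nn as nn


class ConvexCombination(nn.Module):
    """Hypothetical helper (not part of PyTorch): learns weights a_1..a_N
    that are positive and sum to 1, and returns sum_i a_i * T_i."""

    def __init__(self, num_tensors: int):
        super().__init__()
        # Unconstrained logits; zeros give equal weights 1/N at the start.
        self.logits = nn.Parameter(torch.zeros(num_tensors))

    def weights(self) -> torch.Tensor:
        # Softmax maps any real vector to positive entries that sum to 1.
        return torch.softmax(self.logits, dim=0)

    def forward(self, *tensors: torch.Tensor) -> torch.Tensor:
        stacked = torch.stack(tensors, dim=-1)  # shape (..., N)
        return (stacked * self.weights()).sum(dim=-1)  # shape (...)


torch.manual_seed(0)
X, Y, Z = torch.randn(3, 8, 4).unbind(0)  # three tensors of shape (8, 4)

# Toy cost (assumption): match a target that is a known convex combination.
target = 0.1 * X + 0.2 * Y + 0.7 * Z

model = ConvexCombination(num_tensors=3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

for step in range(500):
    optimizer.zero_grad()
    loss = ((model(X, Y, Z) - target) ** 2).mean()
    loss.backward()
    optimizer.step()

print(model.weights())  # should approach roughly [0.1, 0.2, 0.7]
```

Because softmax outputs are strictly positive, a weight can get arbitrarily close to 0 but never reach it exactly; if the optimum lies on an edge of the simplex, the corresponding logit just keeps drifting downward.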