Suppose I want to have the general neural network architecture:
                    ---> BinaryOutput_A
                   /
Input --> nnLayer -
                   \
                    ---> BinaryOutput_B
An input is put through a neural network layer, which then predicts two binary variables (i.e., A is 0 or 1 and B is 0 or 1).
In PyTorch, you can make such a network with:

import torch
import torch.nn as nn
import torch.nn.functional as F

class NN(nn.Module):
    def __init__(self, inputs):
        super(NN, self).__init__()
        # -- first layer
        self.lin = nn.Linear(inputs, 10)
        # -- firstLayer --> binaryOutputA
        self.l2a = nn.Linear(10, 2)
        # -- firstLayer --> binaryOutputB
        self.l2b = nn.Linear(10, 2)

    def forward(self, inputs):
        o = self.lin(inputs)
        o1 = F.log_softmax(self.l2a(o), dim=1)
        o2 = F.log_softmax(self.l2b(o), dim=1)
        return o1, o2
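For reference, here is a quick sketch of how this model might be instantiated and called; the input size of 5 and batch size of 3 are arbitrary values chosen only for illustration:

import torch

# Hypothetical example: 3 samples, 5 input features each
model = NN(inputs=5)
x = torch.randn(3, 5)

# forward() returns one log-softmax output per head
out_A, out_B = model(x)
print(out_A.shape, out_B.shape)  # torch.Size([3, 2]) torch.Size([3, 2])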
In my train function, I calculate the loss with loss = loss_function(output, target). If that's the case, to properly backpropagate the loss to both the l2a and l2b layers using loss.backward(), could I simply concatenate the target with the proper labels for l2a and l2b? In that sense, the output would be [outputPredictionA, outputPredictionB] and the target would be [labelA, labelB]. Would PyTorch know to properly assign the loss to each layer?
It turns out that PyTorch is actually really smart, and you can just calculate the total loss as:
loss = 0
loss += loss_function(pred_A, label_A)
loss += loss_function(pred_B, label_B)
loss.backward()
and the error will be properly backpropagated through the network. No torch.cat()
needed or anything.
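For concreteness, here is a minimal training-step sketch of that idea. The NLLLoss criterion is chosen to pair with the log_softmax outputs above; the input size, optimizer settings, batch, and label tensors are illustrative placeholders, not part of the original question:

import torch
import torch.nn as nn

model = NN(inputs=5)                        # hypothetical input size
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_function = nn.NLLLoss()                # pairs with log_softmax outputs

x = torch.randn(3, 5)                       # illustrative batch of 3 samples
label_A = torch.tensor([0, 1, 1])           # illustrative binary labels for head A
label_B = torch.tensor([1, 0, 1])           # illustrative binary labels for head B

optimizer.zero_grad()
pred_A, pred_B = model(x)

# Summing the two losses builds a single computation graph; backward()
# sends gradients through l2a, l2b, and the shared lin layer.
loss = loss_function(pred_A, label_A) + loss_function(pred_B, label_B)
loss.backward()
optimizer.step()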