Neural network in PyTorch predicting two binary variables


Suppose I want to have the general neural network architecture:

                     ---> BinaryOutput_A
                    / 
Input --> nnLayer -
                    \
                     ---> BinaryOutput_B

An input is passed through a neural network layer, which then feeds two binary predictions (i.e., A is [0 or 1] and B is [0 or 1]).

In PyTorch, you can build such a network with:

import torch.nn as nn
import torch.nn.functional as F

class NN(nn.Module):

    def __init__(self, inputs):
        super(NN, self).__init__()

        # -- shared first layer
        self.lin = nn.Linear(inputs, 10)

        # -- firstLayer --> binaryOutputA (two logits: class 0 vs. class 1)
        self.l2a = nn.Linear(10, 2)

        # -- firstLayer --> binaryOutputB (two logits: class 0 vs. class 1)
        self.l2b = nn.Linear(10, 2)

    def forward(self, inputs):
        o = self.lin(inputs)
        # log-probabilities per head; dim=1 is the class dimension
        o1 = F.log_softmax(self.l2a(o), dim=1)
        o2 = F.log_softmax(self.l2b(o), dim=1)
        return o1, o2
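
As a quick sanity check, a forward pass on a random batch shows that each head returns its own [batch, 2] tensor of log-probabilities (the 8 input features and batch size of 4 here are just placeholder values for illustration):

import torch

model = NN(inputs=8)          # hypothetical feature count
x = torch.randn(4, 8)         # placeholder batch of 4 samples

out_A, out_B = model(x)
print(out_A.shape, out_B.shape)   # torch.Size([4, 2]) torch.Size([4, 2])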

In my train function, I calculate the loss with loss = loss_function(output, target). If that's the case, then to properly backpropagate the loss to both the l2a and l2b layers using loss.backward(), could I simply concatenate the target with the proper labels for l2a and l2b? In that sense, the output would be [outputPredictionA, outputPredictionB] and I could make the target be [labelA, labelB]. Would PyTorch know to properly assign the loss to each layer?


Solution

  • It turns out that torch is actually really smart, and you can just calculate the total loss as:

    loss = 0
    loss += loss_function(pred_A, label_A)
    loss += loss_function(pred_B, label_B)
    
    loss.backward()
    

    and the error will be properly backpropagated through the network. No torch.cat() needed or anything.
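
    For completeness, a minimal training step might look like the sketch below. The random input tensor, the binary targets, the SGD optimizer, and the sizes are placeholder assumptions for illustration; nn.NLLLoss is the natural pairing for log_softmax outputs:

    import torch
    import torch.nn as nn
    import torch.optim as optim

    # -- hypothetical setup: 8 input features, batch of 4 random samples
    model = NN(inputs=8)
    loss_function = nn.NLLLoss()               # pairs with log_softmax outputs
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    x = torch.randn(4, 8)                      # placeholder input batch
    label_A = torch.randint(0, 2, (4,))        # placeholder binary targets, head A
    label_B = torch.randint(0, 2, (4,))        # placeholder binary targets, head B

    optimizer.zero_grad()
    pred_A, pred_B = model(x)

    # -- sum the per-head losses; autograd routes each head's gradient
    # -- to its own layer, while the shared layer receives both contributions
    loss = loss_function(pred_A, label_A) + loss_function(pred_B, label_B)
    loss.backward()
    optimizer.step()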