Tags: machine-learning, pytorch, autoencoder

How to discard a branch after training a pytorch model


I am trying to implement an FCN in PyTorch with the overall structure shown below:

[architecture diagram]

The code so far looks like below:

import torch
import torch.nn as nn

class SNet(nn.Module):
    def __init__(self):
        super(SNet, self).__init__()

        self.enc_a = encoder(...)
        self.dec_a = decoder(...)

        self.enc_b = encoder(...)
        self.dec_b = decoder(...)

    def forward(self, x1, x2):
        x1 = self.enc_a(x1)
        x2 = self.enc_b(x2)
        x2 = self.dec_b(x2)
        x1 = self.dec_a(torch.cat((x1, x2), dim=-1))
        return x1, x2

In Keras this is relatively easy to do with the functional API, but I could not find any concrete example or tutorial for doing it in PyTorch.

  1. How can I discard dec_a (the decoder of the autoencoder branch) after training?
  2. During joint training, should the loss be the (optionally weighted) sum of the losses from the two branches? A rough sketch of what I mean follows this list.
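
For context, this is roughly the joint training step I have in mind (criterion_a, criterion_b, the weights w_a/w_b, and the targets are placeholders, not part of my actual code):

optimizer.zero_grad()
out_a, out_b = model(x1, x2)
loss_a = criterion_a(out_a, target_a)   # loss for branch a
loss_b = criterion_b(out_b, target_b)   # loss for branch b
loss = w_a * loss_a + w_b * loss_b      # optionally weighted sum of the two losses
loss.backward()
optimizer.step()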

Solution

  • You can define separate training and inference modes for your model:

    class SNet(nn.Module):
        def __init__(self):
            super(SNet, self).__init__()

            self.enc_a = encoder(...)
            self.dec_a = decoder(...)

            self.enc_b = encoder(...)
            self.dec_b = decoder(...)

            # nn.Module already defines self.training (toggled by .train()/.eval());
            # setting it here just makes the default explicit
            self.training = True

        def forward(self, x1, x2):
            if self.training:
                # training path: both branches are used
                x1 = self.enc_a(x1)
                x2 = self.enc_b(x2)
                x2 = self.dec_b(x2)
                x1 = self.dec_a(torch.cat((x1, x2), dim=-1))
                return x1, x2
            else:
                # inference path: dec_a is skipped entirely
                x1 = self.enc_a(x1)   # enc_a output is computed but unused here; drop it if not needed
                x2 = self.enc_b(x2)
                x2 = self.dec_b(x2)
                return x2
    

    These blocks are only examples and may not do exactly what you want, since there is some ambiguity between how the training and inference operations are defined in your block chart versus your code. In any case, they show how you can restrict some modules to training mode; you then just set this variable accordingly.
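
    As a rough usage sketch (assuming the encoder/decoder modules above are filled in and x1, x2 are suitably shaped input tensors), you can rely on the built-in model.train() / model.eval() switches, which already toggle self.training, and drop dec_a entirely once training is finished:

    model = SNet()

    # .train() / .eval() set the self.training flag on the module and its children
    model.train()
    out_a, out_b = model(x1, x2)        # training path uses dec_a and dec_b

    model.eval()
    with torch.no_grad():
        out_b = model(x1, x2)           # inference path skips dec_a

    # After training, the dec_a branch can be discarded entirely; deleting the
    # attribute removes the submodule and its parameters from the model.
    del model.dec_a
    torch.save(model.state_dict(), "snet_without_dec_a.pt")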