Tags: tensorflow, keras, deep-learning, pytorch, training-data

Training of multi-headed neural network with labels only for certain heads at a time


I am trying to train a neural network with 3 heads sharing some initial layers. However, each of my training samples has labels for only 2 of the heads.

I would like to create separate batches, each containing only samples that have labels for the same heads, and use each batch to update only the respective heads.

Is there any way to achieve this in any DL framework?


Solution

  • As your question is somewhat general, I will answer assuming you are using PyTorch Lightning. I suggest you use a model that looks like this:

    class MyModel(LightningModule):
      def training_step(self, batch: MyMultiTaskBatch, batch_idx: int):
        # The shared layers run on every sample.
        backbone_output = self.backbone(batch.x)
        # Pick the head and loss function for this batch's task.
        head = self.heads[batch.task_name]
        head_output = head(backbone_output)
        loss_fn = self.losses[batch.task_name]
        return loss_fn(head_output, batch.y)
    

    Here the batch tells the model which head it should run and which loss it should use, looked up in dictionaries that map task names to heads and losses. Since gradients only flow through the head that produced the loss, the other heads are left untouched while the shared backbone is still updated. You will also need to implement a dataloader that returns a MyMultiTaskBatch as its batches.
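
    The batch type and dataloader could be sketched as follows. This is a minimal sketch, not part of the original answer: the names SingleTaskDataset and make_loader are my own, and the key idea is a collate_fn that stacks samples from one task and attaches the task name, so each batch carries labels for exactly one head.

    ```python
    from dataclasses import dataclass

    import torch
    from torch.utils.data import DataLoader, Dataset


    @dataclass
    class MyMultiTaskBatch:
        x: torch.Tensor   # input for the shared backbone
        y: torch.Tensor   # labels for one task only
        task_name: str    # selects the head and the loss


    class SingleTaskDataset(Dataset):
        """Holds samples that all share the same task (head)."""

        def __init__(self, xs, ys, task_name):
            self.xs, self.ys, self.task_name = xs, ys, task_name

        def __len__(self):
            return len(self.xs)

        def __getitem__(self, idx):
            return self.xs[idx], self.ys[idx]


    def make_loader(dataset: SingleTaskDataset, batch_size: int) -> DataLoader:
        # The collate_fn stacks individual samples into one batch and tags
        # it with the dataset's task name.
        def collate(samples):
            xs = torch.stack([x for x, _ in samples])
            ys = torch.stack([y for _, y in samples])
            return MyMultiTaskBatch(x=xs, y=ys, task_name=dataset.task_name)

        return DataLoader(dataset, batch_size=batch_size, collate_fn=collate)
    ```

    With one such loader per task, you can alternate between them during training (for example via Lightning's support for multiple training dataloaders), so every step updates the backbone plus exactly one head.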