python · nlp · pytorch · bert-language-model · huggingface-transformers

calculate two losses in a model and backpropagate twice


I'm creating a model using BertModel to identify the answer span (without using BertForQA).

I have an independent linear layer for determining the start and end tokens, respectively. In __init__():

self.start_linear = nn.Linear(h, output_dim)
self.end_linear = nn.Linear(h, output_dim)

In forward(), I output the predicted start and end distributions:

def forward(self, input_ids, attention_mask):
    outputs = self.bert(input_ids, attention_mask)  # input = bert tokenizer encoding
    lhs = outputs.last_hidden_state  # (batch_size, sequence_length, hidden_size)
    out = lhs[:, -1, :]              # (batch_size, hidden_dim)
    st = self.start_linear(out)
    end = self.end_linear(out)

    predict_start = self.softmax(st)
    predict_end = self.softmax(end)
    return predict_start, predict_end

Then in train_epoch(), I tried to backpropagate the losses separately:

def train_epoch(model, train_loader, optimizer):
    model.train()
    total = 0
    st_loss, st_correct, st_total_loss = 0, 0, 0
    end_loss, end_correct, end_total_loss = 0, 0, 0
    for batch in train_loader:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_idx = batch['start'].to(device)
        end_idx = batch['end'].to(device)
        start, end = model(input_ids=input_ids, attention_mask=attention_mask)

        st_loss = model.compute_loss(start, start_idx)
        end_loss = model.compute_loss(end, end_idx)
        st_total_loss += st_loss.item()
        end_total_loss += end_loss.item()

        # perform backward propagation to compute the gradients
        st_loss.backward()
        end_loss.backward()

        # update the weights
        optimizer.step()

But then, on the line end_loss.backward(), I got:

Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

Am I supposed to do the backward passes separately, or should I do it another way? Thank you!


Solution

  • The standard procedure is just to sum both losses and backpropagate on the sum.

    It can be important to make sure that the two losses you sum have values of roughly the same average magnitude, or at least proportional to the importance you want each to have relative to the other (otherwise, the model will optimize for the bigger loss more than for the smaller one). In the span-detection case, however, I'm guessing this won't be necessary due to the apparent symmetry of the problem. A single combined backward pass is sketched below.
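
    Applied to the training loop in the question, that could look roughly like the sketch below. It assumes model.compute_loss and the rest of the loop are as shown above; w_start and w_end are hypothetical weighting factors you would only add if you want to rescale one loss relative to the other.

        # inside the loop in train_epoch(), instead of two separate backward() calls:
        st_loss = model.compute_loss(start, start_idx)
        end_loss = model.compute_loss(end, end_idx)
        st_total_loss += st_loss.item()
        end_total_loss += end_loss.item()

        # sum the losses (optionally weighted) and do a single backward pass
        loss = st_loss + end_loss        # or: w_start * st_loss + w_end * end_loss
        loss.backward()                  # gradients flow through both heads and the shared encoder
        optimizer.step()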