I'm creating a model using BertModel to identify the answer span (without using BertForQA).
I have an independent linear layer for predicting the start and end token respectively. In __init__():
self.start_linear = nn.Linear(h, output_dim)
self.end_linear = nn.Linear(h, output_dim)
In forward(), I return the predicted start and end distributions:
def forward(self, input_ids, attention_mask):
    outputs = self.bert(input_ids, attention_mask)  # input = bert tokenizer encoding
    lhs = outputs.last_hidden_state  # (batch_size, sequence_length, hidden_size)
    out = lhs[:, -1, :]  # (batch_size, hidden_dim)
    st = self.start_linear(out)
    end = self.end_linear(out)
    predict_start = self.softmax(st)
    predict_end = self.softmax(end)
    return predict_start, predict_end
Then in train_epoch(), I tried to backpropagate the losses separately:
def train_epoch(model, train_loader, optimizer):
    model.train()
    total = 0
    st_loss, st_correct, st_total_loss = 0, 0, 0
    end_loss, end_correct, end_total_loss = 0, 0, 0
    for batch in train_loader:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_idx = batch['start'].to(device)
        end_idx = batch['end'].to(device)
        start, end = model(input_ids=input_ids, attention_mask=attention_mask)
        st_loss = model.compute_loss(start, start_idx)
        end_loss = model.compute_loss(end, end_idx)
        st_total_loss += st_loss.item()
        end_total_loss += end_loss.item()
        # perform backward propagation to compute the gradients
        st_loss.backward()
        end_loss.backward()
        # update the weights
        optimizer.step()
But then I got the following error on the line end_loss.backward():
Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.
Am I supposed to do the backward pass separately? Or should I do it in another way? Thank you!
The standard procedure is just to sum both losses and backpropagate on the sum.
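Both losses share the same forward graph through BERT, so the first .backward() frees the intermediate activations and the second call has nothing left to backpropagate through; summing them first lets a single backward pass compute the gradients for both heads. Here is a minimal sketch of the modified loop body, reusing your names (compute_loss, start_idx, end_idx):

    start, end = model(input_ids=input_ids, attention_mask=attention_mask)
    st_loss = model.compute_loss(start, start_idx)
    end_loss = model.compute_loss(end, end_idx)
    st_total_loss += st_loss.item()
    end_total_loss += end_loss.item()
    # sum the two scalar losses and backpropagate once through the shared graph
    loss = st_loss + end_loss
    loss.backward()
    # update the weights
    optimizer.step()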
It can be important to make sure the losses you sum have values of roughly the same magnitude on average, or at least proportional to the relative importance you want each to have (otherwise, the model will optimize for the larger loss more than for the smaller one). In the span-detection case, I suspect this won't be necessary, given the apparent symmetry of the problem.
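If the two losses ever did need different emphasis, you would scale them before summing. The weight below is a hypothetical hyperparameter for illustration, not something from your code:

    # hypothetical weighting; alpha is a tunable hyperparameter in [0, 1]
    alpha = 0.5
    loss = alpha * st_loss + (1 - alpha) * end_loss
    loss.backward()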