I am fine-tuning a BERT model. First, I want to freeze some layers and train for a while. Then, when a certain callback is triggered (say, ReduceLROnPlateau), I want to unfreeze those layers. How can I do that?
I'm afraid PyTorch's learning rate schedulers don't provide hooks. However, the implementation of `ReduceLROnPlateau` (in `torch.optim.lr_scheduler`) resets two attributes when the scheduler is triggered, i.e. when it detects a plateau and reduces the learning rate:
```python
# excerpt from ReduceLROnPlateau.step()
if self.num_bad_epochs > self.patience:
    self._reduce_lr(epoch)
    self.cooldown_counter = self.cooldown
    self.num_bad_epochs = 0
```
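To see those two resets happen, here is a small sketch that drives a `ReduceLROnPlateau` with a flat metric and records `num_bad_epochs`, `cooldown_counter` and the learning rate after every step. The dummy parameter, `patience=2`, `cooldown=1` and the constant metric value are arbitrary choices for illustration. On steps 4 and 8, the two steps where the learning rate drops, the trace shows `num_bad_epochs=0` together with `cooldown_counter=1`.

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

# A single dummy parameter is enough to build an optimizer for the scheduler.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = ReduceLROnPlateau(optimizer, patience=2, cooldown=1)

trace = []
for step in range(1, 11):
    scheduler.step(1.0)  # flat metric: it never improves after the first step
    trace.append((step, scheduler.num_bad_epochs,
                  scheduler.cooldown_counter, optimizer.param_groups[0]["lr"]))

for row in trace:
    print("step %2d  num_bad_epochs=%d  cooldown_counter=%d  lr=%.4g" % row)
```

Note that step 1 also shows both counters at 0. That only stands out from steps 4 and 8 because `cooldown` is nonzero here.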
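Based on that, you could wrap your `scheduler.step()` call and detect whether `_reduce_lr` was triggered by checking that both `scheduler.cooldown_counter == scheduler.cooldown` and `scheduler.num_bad_epochs == 0` hold right after the step. This check is only unambiguous when `cooldown` is at least 1. With the default `cooldown=0`, `cooldown_counter` is always 0, and `num_bad_epochs` is also reset to 0 on every step where the metric improves, so the condition would be true after each improvement as well.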
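Putting it together, a minimal sketch of the wrap-and-check approach, assuming a toy two-layer model stands in for BERT (`encoder` plays the pretrained layers you freeze, `head` the new classifier) and a constant fake validation loss so the plateau is reached deterministically. The helpers `set_requires_grad` and `step_and_check` are hypothetical names introduced here, not part of PyTorch. For a real BERT model you would pass whichever submodule you froze to `set_requires_grad`.

```python
import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau


def set_requires_grad(module: nn.Module, flag: bool) -> None:
    """Hypothetical helper: freeze (False) or unfreeze (True) all parameters."""
    for p in module.parameters():
        p.requires_grad = flag


def step_and_check(scheduler: ReduceLROnPlateau, metric: float) -> bool:
    """Hypothetical helper: step the scheduler and report a plateau trigger."""
    # With cooldown == 0 the check would also fire after every improvement.
    if scheduler.cooldown < 1:
        raise ValueError("use cooldown >= 1 so the trigger check is unambiguous")
    scheduler.step(metric)
    return (scheduler.cooldown_counter == scheduler.cooldown
            and scheduler.num_bad_epochs == 0)


torch.manual_seed(0)
# Toy stand-in for BERT (assumption): a "pretrained" encoder plus a new head.
encoder = nn.Linear(8, 8)
head = nn.Linear(8, 2)
model = nn.Sequential(encoder, nn.ReLU(), head)

set_requires_grad(encoder, False)  # start with the encoder frozen
# Pass *all* parameters to the optimizer up front: frozen ones never get a
# .grad, so the optimizer skips them until they are unfrozen.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = ReduceLROnPlateau(optimizer, patience=2, cooldown=1)

x, y = torch.randn(16, 8), torch.randint(0, 2, (16,))
loss_fn = nn.CrossEntropyLoss()
unfrozen_at = None
for epoch in range(10):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    val_loss = 1.0  # fake, flat validation metric (assumption)
    if step_and_check(scheduler, val_loss) and unfrozen_at is None:
        set_requires_grad(encoder, True)  # plateau hit: unfreeze the encoder
        unfrozen_at = epoch

print("encoder unfrozen after epoch", unfrozen_at)
```

Giving the optimizer every parameter from the start avoids having to rebuild it (or call `add_param_group`) after unfreezing. An alternative is to compare `optimizer.param_groups[0]["lr"]` before and after `scheduler.step()`, but that misses triggers once the learning rate has already reached `min_lr`.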