
How can I detect if a callback is triggered in PyTorch?


I am fine-tuning a BERT model. First, I want to freeze layers and train for a while. When a certain callback is triggered (let's say ReduceLROnPlateau), I want to unfreeze the layers. How can I do it?


Solution

  • I'm afraid learning rate schedulers in PyTorch don't provide hooks. Looking at the implementation of ReduceLROnPlateau here, two properties are reset when the scheduler is triggered (i.e. when it identifies a plateau and reduces the learning rate):

        if self.num_bad_epochs > self.patience:
            self._reduce_lr(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0
    

    Based on that, you could wrap your scheduler step call and detect whether _reduce_lr was triggered by checking that both scheduler.cooldown_counter == scheduler.cooldown and scheduler.num_bad_epochs == 0 hold immediately after the step.
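
    A minimal sketch of that check (the helper name step_and_check and the toy model are illustrative, not part of any API). One caveat worth flagging: with the default cooldown=0, the condition also holds before any reduction has happened (both values start at 0), so this heuristic only works reliably with cooldown > 0.

    ```python
    import torch

    # Toy model and optimizer standing in for a frozen BERT setup.
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # cooldown > 0 so the detection heuristic below cannot misfire on epoch 0.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=2, cooldown=1
    )

    def step_and_check(scheduler, metric):
        """Step the scheduler; return True if _reduce_lr just fired."""
        scheduler.step(metric)
        # Both properties are reset together only inside the reduction branch.
        return (scheduler.cooldown_counter == scheduler.cooldown
                and scheduler.num_bad_epochs == 0)

    unfrozen = False
    for epoch in range(10):
        val_loss = 1.0  # a constant loss simulates a plateau
        if step_and_check(scheduler, val_loss) and not unfrozen:
            # This is where you would unfreeze your BERT layers.
            for p in model.parameters():
                p.requires_grad = True
            unfrozen = True
    ```

    An alternative that avoids the caveat entirely is to record optimizer.param_groups[0]['lr'] before the step and compare it afterwards; a drop means the scheduler triggered.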