pytorch learning-rate

PyTorch: looking for a function that lets me manually set learning rates for specific epoch intervals


For example, set lr = 0.01 for the first 100 epochs, lr = 0.001 for epochs 101 to 1000, and lr = 0.0005 for epochs 1001 to 4000. In other words, my schedule does not decay exponentially at a fixed step interval. I know this can be done with a self-defined function; I'm just curious whether there is already a built-in function for it.


Solution

  • torch.optim.lr_scheduler.LambdaLR is what you are looking for. It takes a function that returns a multiplier applied to the optimizer's initial learning rate, so you can specify any value for any given epoch. For your example it would be:

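    from torch.optim.lr_scheduler import LambdaLR

    def lr_lambda(epoch: int) -> float:
        # `epoch` is zero-based; the return value multiplies the initial lr (0.01)
        if epoch < 100:
            return 1.0   # epochs 1-100: lr = 0.01
        if epoch < 1000:
            return 0.1   # epochs 101-1000: lr = 0.001
        return 0.05      # epochs 1001-4000: lr = 0.0005

    # Optimizer has lr set to 0.01
    scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
    for epoch in range(4000):
        train(...)      # expected to call optimizer.step() for each batch
        validate(...)
        scheduler.step()  # advance the schedule once per epoch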
    

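    PyTorch ships schedulers for common cases (such as MultiStepLR or ExponentialLR), but MultiStepLR applies the same decay factor at every milestone, which does not fit a 10x drop followed by a 2x drop. For a custom schedule like yours, LambdaLR is the easiest option.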
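
If the schedule has more than a few intervals, the chain of `if` statements can be replaced by a lookup table. The following is a minimal sketch, not part of PyTorch: `piecewise_constant`, `boundaries`, and `multipliers` are hypothetical names chosen for illustration, and the epoch index is assumed to be zero-based, as LambdaLR passes it.

```python
from bisect import bisect_right

def piecewise_constant(boundaries, multipliers):
    """Build an lr_lambda from a table (hypothetical helper, not a PyTorch API).

    boundaries:  sorted zero-based epoch indices at which a new interval starts
    multipliers: one multiplier per interval, so len(boundaries) + 1 entries
    """
    if len(multipliers) != len(boundaries) + 1:
        raise ValueError("need exactly one more multiplier than boundaries")

    def lr_lambda(epoch):
        # bisect_right counts how many boundaries are <= epoch,
        # which is the index of the interval the epoch falls into
        return multipliers[bisect_right(boundaries, epoch)]

    return lr_lambda

# Schedule from the question, relative to an initial lr of 0.01:
# epochs 1-100 -> x1.0, epochs 101-1000 -> x0.1, epochs 1001-4000 -> x0.05
lr_lambda = piecewise_constant([100, 1000], [1.0, 0.1, 0.05])

# It can then be passed to the scheduler as in the answer above:
# scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
```

Keeping the boundaries in a list makes the schedule easy to change without touching the function body, and the length check catches a mismatched table before training starts.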