
How to set a new learning rate manually in an optax optimizer?


I have the following optimizer created using optax:

import optax


def create_optimizer(learning_rate=6.25e-2, beta1=0.4, beta2=0.999,
                     eps=2e-4, centered=False):
  """Creates an Adam optimizer.

  Returns:
    An optax optimizer.
  """
  return optax.adam(learning_rate, b1=beta1, b2=beta2, eps=eps)

How can I update this learning rate manually during training?

I couldn't find any documentation about that.


Solution

  • Disclaimer. Usually, you would use a schedule to adapt the learning rate during training. This answer provides a solution to obtain direct control over the learning rate.
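
    For reference, a schedule can simply be passed in place of a constant learning rate. The sketch below assumes optax.exponential_decay with made-up decay settings; any optax schedule works the same way. The rest of this answer covers the manual approach instead.

    import optax

    # Schedule-based alternative (sketch): the learning rate decays
    # automatically with the step count, with no manual intervention.
    schedule = optax.exponential_decay(
        init_value=1e-2, transition_steps=1_000, decay_rate=0.9)
    opt = optax.adam(learning_rate=schedule)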


    In general, you can put any of an optimizer's hyperparameters (such as the learning rate) into the optimizer's state and then mutate the state directly. Moving the hyperparameters into the state is necessary because optax optimizers are pure functions: the only way to change their behaviour dynamically is to change their inputs.

    Setup. I am using a stochastic gradient descent optimizer to highlight the effect of the learning rate on the update suggested by the optimizer.

    import jax.numpy as jnp
    import optax
    
    # Define example parameters and gradients.
    params, grads = jnp.array([0.0, 0.0]), jnp.array([1.0, 2.0])
    
    # Ensure the learning rate is part of the optimizer's state.
    opt = optax.inject_hyperparams(optax.sgd)(learning_rate=1e-2)
    opt_state = opt.init(params)
    

    Update computation.

    updates, _ = opt.update(grads, opt_state)
    updates
    
    Array([-0.01, -0.02], dtype=float32)
    

    Directly setting the learning rate.

    opt_state.hyperparams['learning_rate'] = 3e-4
    

    Same update computation as before (with new learning rate).

    updates, _ = opt.update(grads, opt_state)
    updates
    
    Array([-0.0003, -0.0006], dtype=float32)
    

    See this discussion for more information.
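
    Putting the pieces together, here is a minimal training-loop sketch; the quadratic loss, the toy data, and the step at which the rate is lowered are purely illustrative assumptions. The learning rate is changed mid-training by mutating the injected hyperparameters exactly as above.

    import jax
    import jax.numpy as jnp
    import optax

    def loss_fn(params, x, y):
      # Toy squared-error loss, for illustration only.
      return jnp.mean((params @ x - y) ** 2)

    params = jnp.zeros(2)
    x, y = jnp.array([1.0, 2.0]), jnp.array(3.0)

    # Make the learning rate part of the optimizer's state.
    opt = optax.inject_hyperparams(optax.sgd)(learning_rate=1e-2)
    opt_state = opt.init(params)

    for step in range(100):
      if step == 50:
        # Manually lower the learning rate halfway through training.
        opt_state.hyperparams['learning_rate'] = 1e-3
      grads = jax.grad(loss_fn)(params, x, y)
      updates, opt_state = opt.update(grads, opt_state)
      params = optax.apply_updates(params, updates)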