python, tensorflow, machine-learning, tensorflow2.0, tensorflow-agents

tf-agent, QNetwork => DqnAgent w/ tfa.optimizers.CyclicalLearningRate


Is there an easy native way to implement tfa.optimizers.CyclicalLearningRate w/ QNetwork on DqnAgent?

Trying to avoid writing my own DqnAgent.

I guess the better question might be, what is a proper way to implement callbacks on DqnAgent?


Solution

  • From the tutorial you linked, the part where they set the optimizer is

    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    
    train_step_counter = tf.Variable(0)
    
    agent = dqn_agent.DqnAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        td_errors_loss_fn=common.element_wise_squared_loss,
        train_step_counter=train_step_counter)
    
    agent.initialize()
    

    So you can replace optimizer with whatever optimizer you would rather use. CyclicalLearningRate is a learning rate schedule rather than an optimizer, so it has to be instantiated and passed to a Keras optimizer. Based on the documentation, something like

    clr_schedule = tfa.optimizers.CyclicalLearningRate(...)  # instantiate with its required arguments (see the sketch below)
    optimizer = tf.keras.optimizers.Adam(learning_rate=clr_schedule)
    

    should work, barring any potential compatibility issues arising from the tutorial's use of the TF 1.x Adam optimizer (tf.compat.v1.train.AdamOptimizer).
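
    A minimal end-to-end sketch, assuming tensorflow_addons is installed and that q_net and train_env are already defined as in the tutorial; the schedule values (initial_learning_rate, maximal_learning_rate, step_size) are illustrative, not recommendations:

    import tensorflow as tf
    import tensorflow_addons as tfa
    from tf_agents.agents.dqn import dqn_agent
    from tf_agents.utils import common

    # Cyclical schedule: the learning rate oscillates between initial_learning_rate
    # and maximal_learning_rate, completing one cycle every 2 * step_size optimizer steps.
    clr_schedule = tfa.optimizers.CyclicalLearningRate(
        initial_learning_rate=1e-4,
        maximal_learning_rate=1e-3,
        step_size=2000,
        scale_fn=lambda x: 1.0,   # constant amplitude on every cycle
        scale_mode='cycle')

    # A LearningRateSchedule instance can be passed wherever a float learning rate goes.
    optimizer = tf.keras.optimizers.Adam(learning_rate=clr_schedule)

    train_step_counter = tf.Variable(0)

    agent = dqn_agent.DqnAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        td_errors_loss_fn=common.element_wise_squared_loss,
        train_step_counter=train_step_counter)

    agent.initialize()

    The schedule is evaluated against the Keras optimizer's own iteration counter, so the learning rate cycles automatically as agent.train() is called; no callback mechanism on DqnAgent should be needed for this.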