I'm learning about Actor-Critic reinforcement learning techniques, in particular the A2C algorithm.
I've found a good description of a simple version of the algorithm (i.e. without experience replay, batching or other tricks) with implementation here: https://link.medium.com/yi55uKWwV2. The complete code from that article is available on GitHub.
I think I understand roughly what's happening here, but to make sure I actually do, I'm trying to reimplement it from scratch using the higher-level tf.keras APIs. Where I'm getting stuck is how to implement the training loop correctly, and how to formulate the actor's loss function.
The code I have so far: https://gist.github.com/nevkontakte/beb59f29e0a8152d99003852887e7de7
Edit: I suppose some of my confusion stems from a poor understanding of the magic behind gradient computation in Keras/TensorFlow, so any pointers there would be appreciated.
First, credit where credit is due: information provided by ralf htp and Simon was instrumental in helping me to figure out the right answers eventually.
Before I go into detailed answers to my own questions, here's the original code I was trying to rewrite in tf.keras terms, and here's my result.
There is a difference between what a raw TF optimizer considers a loss function and what Keras does. When you use an optimizer directly, it simply expects a tensor (lazy or eager, depending on your configuration), which will be evaluated under tf.GradientTape() to compute the gradient and update the weights.
Example from https://medium.com/@asteinbach/actor-critic-using-deep-rl-continuous-mountain-car-in-tensorflow-4c1fb2110f7c:
# Below norm_dist is the output tensor of the neural network we are training.
loss_actor = -tfc.log(norm_dist.prob(action_placeholder) + 1e-5) * delta_placeholder
training_op_actor = tfc.train.AdamOptimizer(
    lr_actor, name='actor_optimizer').minimize(loss_actor)
# Later, in the training loop...
_, loss_actor_val = sess.run([training_op_actor, loss_actor],
                             feed_dict={action_placeholder: np.squeeze(action),
                                        state_placeholder: scale_state(state),
                                        delta_placeholder: td_error})
In this example, sess.run() evaluates the whole graph: it runs the inference, captures the gradient and adjusts the weights. So to get whatever values you need into the loss function or gradient computation, you just feed them into the computation graph.
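To make the loss expression concrete, here is a minimal sketch in plain Python of what `-log(prob(action) + 1e-5) * delta` evaluates to for a single sample, assuming the policy outputs a normal distribution over actions. The function names and the numbers are made up for illustration; they are not from the article.

```python
import math

def normal_pdf(x, mean, stddev):
    """Density of a normal distribution N(mean, stddev^2) at x."""
    z = (x - mean) / stddev
    return math.exp(-0.5 * z * z) / (stddev * math.sqrt(2.0 * math.pi))

def actor_loss(action, mean, stddev, td_error, eps=1e-5):
    """Same form as the TF expression above: -log(pi(a|s) + eps) * delta."""
    return -math.log(normal_pdf(action, mean, stddev) + eps) * td_error

# Hypothetical numbers: the policy proposes N(0.0, 1.0) and the sampled action was 0.5.
loss_good = actor_loss(action=0.5, mean=0.0, stddev=1.0, td_error=+2.0)
loss_bad = actor_loss(action=0.5, mean=0.0, stddev=1.0, td_error=-2.0)
```

The two losses differ only in sign. With a positive TD error, minimizing the loss raises the probability of the action that was taken; with a negative one, it lowers it.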
Keras is a bit more formal about what a loss function should look like:
loss: String (name of objective function), objective function or tf.keras.losses.Loss instance. See tf.keras.losses. An objective function is any callable with the signature scalar_loss = fn(y_true, y_pred). If the model has multiple outputs, you can use a different loss on each output by passing a dictionary or a list of losses. The loss value that will be minimized by the model will then be the sum of all individual losses.
Keras will do the inference (forward pass) for you and pass the output into the loss function. The loss function is supposed to do some extra computation on the predicted value and the y_true label, and return the result. This whole process will be tracked for the purpose of gradient computation.
Although this is very convenient for traditional training, it is a bit restrictive when we want to pass some extra data in, like the TD error. It is possible to work around that by shoving all the extra data into y_true and pulling it apart inside the loss function (I found this trick somewhere on the web, but unfortunately lost the link to the source).
Here's how I rewrote the above in the end:
def loss(y_true, y_pred):
    action_true = y_true[:, :n_outputs]
    advantage = y_true[:, n_outputs:]
    return -tfc.log(y_pred.prob(action_true) + 1e-5) * advantage
# Below, in the training loop...
# A trick to pass TD error *and* actual action to the loss function: join them into a tensor
# and split them apart inside the loss function.
annotated_action = tf.concat([action, td_error], axis=1)
actor_model.train_on_batch([scale_state(state)], [annotated_action])
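The packing and unpacking can be checked in isolation. Below is a small NumPy sketch of the same slicing, assuming a single continuous action dimension (`n_outputs = 1`) and a made-up batch of three transitions; it only mirrors the shapes and is not the Keras code itself.

```python
import numpy as np

n_outputs = 1  # assumed: one continuous action dimension

# A hypothetical batch of 3 transitions.
action = np.array([[0.3], [-0.7], [0.1]])    # shape (3, n_outputs)
td_error = np.array([[1.5], [-0.2], [0.0]])  # shape (3, 1)

# Pack: one row per sample, action columns first, TD error last.
annotated_action = np.concatenate([action, td_error], axis=1)  # shape (3, n_outputs + 1)

# Unpack, the same way the loss function slices y_true.
action_true = annotated_action[:, :n_outputs]
advantage = annotated_action[:, n_outputs:]
```

Both inputs have to be 2-D (batch dimension first) and share a dtype for the concatenation along axis 1 to work, which is why the TD error is kept as a column here rather than a flat vector.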
When I asked this question, I didn't understand well enough how the TF compute graph works. So the answer is simple: every time sess.run() is invoked, it must compute the whole graph from scratch. The parameters of the distribution will be the same (or similar) as long as the graph inputs (e.g. the observed state) and the NN weights are the same (or similar).
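A toy illustration of that point, using NumPy instead of TensorFlow: the "network" below is a hypothetical single linear layer that maps a state to a mean and a standard deviation. It is only a stand-in for the real actor model.

```python
import numpy as np

def policy_params(state, weights):
    """Toy 'network': a linear layer mapping a state to (mean, stddev)."""
    out = state @ weights                # shape (2,)
    mean = out[0]
    stddev = np.log1p(np.exp(out[1]))    # softplus keeps stddev positive
    return mean, stddev

rng = np.random.default_rng(0)
weights = rng.normal(size=(2, 2))  # hypothetical weights
state = np.array([0.5, -0.1])      # hypothetical observation

first = policy_params(state, weights)
second = policy_params(state, weights)          # recomputed from scratch, same result
updated = policy_params(state, weights + 0.01)  # after a small weight change, slightly different
```

Recomputing with the same state and weights gives identical distribution parameters, and a small weight update shifts them only slightly.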
What's wrong is the assumption "the actor's loss function doesn't care about y_pred" :) The actor's loss function involves norm_dist (which is the action probability distribution), which is effectively an analog of y_pred in this context.