Tags: deep-learning, reinforcement-learning, openai-gym, q-learning

How does Deep Q-learning work?


When I am training my model I have the following segment:

s_t_batch, a_batch, y_batch = train_data(minibatch, model2)
# perform gradient step
loss.append(model.train_on_batch([s_t_batch, a_batch], y_batch))

where s_t_batch and a_batch correspond to the current states and the actions that were taken in those states, respectively. model2 is the same as model, except that model2 outputs num_actions values (one Q-value per action), while model only outputs the value of the action that was taken in that state.
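To make the relationship between the two networks concrete: the output of a network like model can be seen as model2's per-action output with only the entry for the taken action picked out. The sketch below does that selection in NumPy on made-up Q-values; `q_all` and `taken_action_values` are hypothetical names, not part of the asker's code.

```python
import numpy as np

def taken_action_values(q_all, actions):
    """Pick Q(s_t, a_t) out of a (batch, num_actions) array of Q-values.

    q_all   : per-action Q-values, e.g. what a network like model2 predicts
    actions : (batch,) integer index of the action taken in each state
    """
    rows = np.arange(len(actions))
    return q_all[rows, actions]  # one value per row, like model's output

q_all = np.array([[1.0, 2.0],
                  [0.5, 0.1],
                  [3.0, 4.0]])
a_batch = np.array([1, 0, 0])
print(taken_action_values(q_all, a_batch).tolist())  # [2.0, 0.5, 3.0]
```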

What I find strange (and this is really the focus of this question) is that in the function train_data I have the line:

y_batch = r_batch + GAMMA * np.max(model.predict(s_t_batch), axis=1)
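For reference, the usual DQN target is built from the network's predictions for the next states s_{t+1}, not the current states, and the bootstrap term is dropped when s_{t+1} is terminal: y = r + γ · max_a' Q(s_{t+1}, a'). A minimal NumPy sketch, assuming `q_next` holds per-action predictions for the next states (for example from a full-output network like model2) and `done_batch` flags terminal transitions; both names are hypothetical.

```python
import numpy as np

def td_targets(r_batch, q_next, done_batch, gamma):
    """y = r + gamma * max_a' Q(s_{t+1}, a'), with no bootstrap on terminal steps.

    r_batch    : (batch,) rewards
    q_next     : (batch, num_actions) Q-values predicted for the NEXT states
    done_batch : (batch,) True where s_{t+1} is terminal
    """
    r_batch = np.asarray(r_batch, dtype=np.float64)
    not_done = 1.0 - np.asarray(done_batch).astype(np.float64)
    # Terminal transitions keep only the reward; others add the discounted estimate
    return r_batch + gamma * not_done * np.max(q_next, axis=1)

y = td_targets(np.array([1.0, 1.0]),
               np.array([[0.5, 2.0], [3.0, 1.0]]),
               np.array([False, True]),
               gamma=0.5)
print(y.tolist())  # [2.0, 1.0]
```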

The strange part is that I am using the model to generate my y_batch as well as to train on it. Doesn't this become some sort of self-fulfilling prophecy? If I understand correctly, the model tries to predict the expected maximum reward. Doesn't using the same model to generate y_batch imply that it is the true model?

My questions are: 1. What is the intuition behind using the same model to generate y_batch as the one being trained on it? 2. (Optional) Does the loss value mean anything? When I plot it, it doesn't seem to be converging; however, the sum of rewards seems to be increasing (see the plots in the link below).

The full code, an implementation of Deep Q-learning on the CartPole-v0 problem, can be found here:

Comments from other forums:

  1. y = r + gamma*np.max(model.predict(s_t_batch), axis=1) is totally natural, and y will converge to the true state-action value. And if you don't break the correlation between consecutive updates with something like experience replay (or, better, prioritized experience replay), your model WILL diverge. There are also variants such as DDQN and the Dueling Network that perform better.
  2. y_batch includes the reward. Both the target and online networks are estimates. It is indeed somewhat of a self-fulfilling prophecy, as DQN's value function is overly optimistic. That is why Double DQN was introduced a few months later.
  3. y will converge, but not necessarily to the true (I assume you mean optimal) state-action value. No one has proven that the converged value is the optimal value, but it is the best approximation we have. However, it will converge to the true value for simple enough problems (e.g. grid-world).
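The first comment's point about breaking the correlation between consecutive updates can be illustrated with a minimal uniform experience replay buffer: transitions are stored as they happen, and minibatches are drawn at random from the stored history, so consecutive, highly correlated steps rarely land in the same update. This is an illustrative sketch with hypothetical class and method names, not the asker's implementation and not prioritized replay.

```python
import random
from collections import deque

class ReplayBuffer:
    """Minimal uniform experience replay (no prioritization)."""

    def __init__(self, capacity, seed=None):
        # Once full, the oldest transitions are discarded automatically
        self.buffer = deque(maxlen=capacity)
        self.rng = random.Random(seed)

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        # Uniform random draw without replacement from the stored history
        batch = self.rng.sample(list(self.buffer), batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return states, actions, rewards, next_states, dones

    def __len__(self):
        return len(self.buffer)
```

A training loop would call add() after every environment step and, once len(buffer) reaches the batch size, build s_t_batch, a_batch and the targets from sample(batch_size) instead of from the most recent steps.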
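Double DQN, mentioned in the first two comments, reduces the over-optimism of taking a max over a single network's own estimates: the online network chooses the next action and the separate target network evaluates it. A NumPy sketch of just the target computation, with hypothetical argument names; how the two networks are built and synchronized is left out.

```python
import numpy as np

def double_dqn_targets(r_batch, q_next_online, q_next_target, done_batch, gamma):
    """Double DQN target: select a' with the online net, evaluate it with the target net.

    q_next_online, q_next_target : (batch, num_actions) predictions for the next states
    """
    best_actions = np.argmax(q_next_online, axis=1)                     # selection
    rows = np.arange(len(best_actions))
    evaluated = q_next_target[rows, best_actions]                       # evaluation
    not_done = 1.0 - np.asarray(done_batch).astype(np.float64)
    return np.asarray(r_batch, dtype=np.float64) + gamma * not_done * evaluated

y = double_dqn_targets(np.array([0.0, 0.0]),
                       np.array([[1.0, 2.0], [5.0, 0.0]]),
                       np.array([[10.0, 3.0], [4.0, 20.0]]),
                       np.array([False, False]),
                       gamma=1.0)
print(y.tolist())  # [3.0, 4.0]
```

With these made-up numbers a plain max over the target network would give [10.0, 20.0]; decoupling selection from evaluation keeps a single inflated estimate from being picked and trusted at the same time.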

Solution

  • The fact that the model trains on its own predictions is the whole point of Q-learning: it is a concept called bootstrapping, which means updating an estimate using other estimates (here, the model's own prediction for the next state). The insight behind this is:

    • The Agent is initialized with some weights
    • These weights represent the Agent's current representation of the Q-Value function it is trying to approximate
    • Then it acts on the environment, performing the action it believes to be of highest Q-Value (with some randomness for exploration)
    • Then it receives feedback from the environment: a reward and the new state it is in
    • By comparing the Agent's Q-Value approximation for state t (= [s_t_batch, a_batch]) with the reward plus its (discounted) approximation for state t+1 (= y_batch), it is able to measure how wrong its prediction for Q_t is.
    • From this measure of error (called the TD-Error), the weights are updated in the direction that lowers the MSE, as in any other gradient-based optimization.
    • (One could wait for more than one step to get more information from the environment and update the weights in an even better direction. One could actually wait for the whole episode to be over and train on that. This continuum between training instantly and waiting for the end is called TD(λ); you should look into it.)
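The steps above can be sketched with tabular Q-learning, where a lookup table plays the role of the network's weights and a learning rate `alpha` replaces the gradient step. This assumes a hypothetical five-state chain (not CartPole): the agent starts in state 0, can step left or right, and receives reward 1 for reaching the last state, which ends the episode. Each update moves Q(s, a) toward the bootstrapped target r + γ · max_a' Q(s', a'), i.e. it shrinks the TD error described above.

```python
import random

def tabular_q_learning(n_states=5, gamma=0.9, alpha=0.5, episodes=2000, seed=0):
    """Tabular Q-learning on a hypothetical deterministic chain.

    States 0..n_states-1; action 0 = left, action 1 = right.
    Reaching the last state gives reward 1 and ends the episode.
    """
    rng = random.Random(seed)
    goal = n_states - 1
    q = [[0.0, 0.0] for _ in range(n_states)]  # initial "weights"
    for _ in range(episodes):
        s = 0
        while s != goal:
            # Uniformly random behaviour policy; Q-learning is off-policy, so it
            # still learns the values of the greedy policy
            a = rng.randrange(2)
            s_next = max(s - 1, 0) if a == 0 else s + 1
            done = s_next == goal
            r = 1.0 if done else 0.0
            # Bootstrapped target: reward plus the table's own discounted estimate
            target = r if done else r + gamma * max(q[s_next])
            q[s][a] += alpha * (target - q[s][a])  # step toward the target
            s = s_next
    return q

q = tabular_q_learning()
```

On a problem this small, with every state-action pair visited many times, the table converges to the optimal values (for example Q(0, right) ≈ 0.9³ = 0.729 with γ = 0.9), in line with the third comment's remark about simple problems such as grid-worlds. When a neural network approximates Q, there is no such general guarantee.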
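The multi-step idea in the last bullet can be made concrete with n-step returns and the (forward-view) λ-return, which blends all n-step returns with weights (1 − λ)λ^(n−1): λ = 0 gives the one-step target used in the question and λ = 1 gives the full Monte Carlo return of the episode. A plain-Python sketch with hypothetical function names, assuming `state_values[k]` holds the bootstrap estimate max_a Q(s_k, a) for each non-terminal state of one finished episode. Combining multi-step returns with off-policy Q-learning needs extra care, which this sketch ignores.

```python
def n_step_return(rewards, state_values, t, n, gamma):
    """n-step target for time t within one finished episode.

    rewards[k]      : reward received after acting in s_k (k = 0..T-1)
    state_values[k] : bootstrap estimate max_a Q(s_k, a) (k = 0..T-1)
    """
    T = len(rewards)
    horizon = min(n, T - t)
    g = sum(gamma ** i * rewards[t + i] for i in range(horizon))
    if t + n < T:  # episode not over after n steps: bootstrap from the estimate
        g += gamma ** n * state_values[t + n]
    return g

def lambda_return(rewards, state_values, t, gamma, lam):
    """Forward-view lambda-return: weighted mix of all n-step returns from t."""
    T = len(rewards)
    monte_carlo = n_step_return(rewards, state_values, t, T - t, gamma)
    g = sum((1 - lam) * lam ** (n - 1) * n_step_return(rewards, state_values, t, n, gamma)
            for n in range(1, T - t))
    # The remaining weight lam^(T-t-1) goes to the full Monte Carlo return
    return g + lam ** (T - t - 1) * monte_carlo

rewards, values = [1.0, 1.0, 1.0], [0.5, 0.5, 0.5]
print(lambda_return(rewards, values, 0, 0.5, 0.0))  # 1.25 (one-step target)
print(lambda_return(rewards, values, 0, 0.5, 1.0))  # 1.75 (Monte Carlo return)
```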

    Your loss means exactly this: for one batch, it is the mean squared error between the model's Q-Value prediction for time t and the target for time t, which combines the model's (discounted) Q-Value estimate for the next state with some "ground truth" from the environment, namely the reward for this timestep.
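The quantity described above is the squared TD error averaged over the batch. A NumPy sketch, assuming `q_taken` is the model's output Q(s_t, a_t) and `y_batch` the bootstrapped targets (hypothetical argument names). Because y_batch is recomputed from the current weights, the regression target itself moves as the model learns, which is one reason a DQN loss curve need not decrease steadily even while the episode returns improve.

```python
import numpy as np

def td_mse_loss(q_taken, y_batch):
    """Mean squared TD error over one batch.

    q_taken : (batch,) predictions Q(s_t, a_t) for the actions actually taken
    y_batch : (batch,) targets r + gamma * max_a' Q(s_{t+1}, a')
    """
    td_error = np.asarray(y_batch, dtype=np.float64) - np.asarray(q_taken, dtype=np.float64)
    return float(np.mean(td_error ** 2))

print(td_mse_loss([1.0, 2.0], [1.5, 1.0]))  # 0.625
```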

    It seems to me that your loss does go down; however, it is very unstable, which is a known issue of vanilla Q-Learning, especially vanilla Deep Q-Learning. Look at the overview paper below to get an idea of how more complex algorithms work.
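To check whether a noisy loss curve is trending down despite the instability, it can help to plot a moving average next to the raw values. A small NumPy helper, assuming the per-batch losses are collected in a list like the `loss` list from the question; the function name and window size are illustrative choices.

```python
import numpy as np

def moving_average(values, window):
    """Simple moving average; the output has len(values) - window + 1 points."""
    values = np.asarray(values, dtype=np.float64)
    if not 1 <= window <= len(values):
        raise ValueError("window must be between 1 and len(values)")
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode="valid")

print(moving_average([1.0, 2.0, 3.0, 4.0], 2).tolist())  # [1.5, 2.5, 3.5]
```

A window of a few hundred batches is usually enough to see the trend; the choice is a trade-off between smoothing and lag.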

    I advise you to look into Temporal Difference Learning. Good resources are also