Tags: java, algorithm, neural-network, deep-learning, q-learning

How to implement Deep Q-learning gradient descent


So I'm trying to implement the Deep Q-learning algorithm created by Google DeepMind, and I think I've got a pretty good hang of it by now. Yet there is still one (pretty important) thing I don't really understand, and I hope you can help.

Doesn't y_j come out as a double (in Java), while the latter part comes from a matrix containing the Q-values for each action in the current state, in the following line (the fourth-to-last line of the algorithm):


(y_j − Q(φ_j, a_j; θ))²


So how can I subtract one from the other?

Should I make y_j a matrix containing all the Q-values from the feedforward pass for the current state, except with the entry for the currently selected action replaced by

r_j + γ max_a’ Q(φ_{j+1}, a’; θ)

This doesn't seem like the right answer, and as you can see, I'm a bit lost here.


(Screenshot of the deep Q-learning algorithm pseudocode from the DeepMind paper.)


Solution

  • Actually found it myself. (Got it right from the start :D) The correct procedure is the following; a Java sketch follows the list.

    1. Do a feedforward pass for the current state s to get predicted Q-values for all actions.
    2. Do a feedforward pass for the next state s’ and calculate the maximum over all network outputs, max_a’ Q(s’, a’).
    3. Set the Q-value target for the action taken to r + γ max_a’ Q(s’, a’) (use the max calculated in step 2). For all other actions, set the Q-value target to the value originally returned in step 1, making the error 0 for those outputs.
    4. Update the weights using backpropagation.
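
To make the bookkeeping concrete, here is a minimal Java sketch of those four steps. The `QNetwork` interface and its `predict`/`train` methods are assumptions made up for illustration, not part of the original algorithm; substitute whatever your network implementation actually exposes. The `terminal` case (y_j = r_j when the episode ends) comes from the paper's algorithm, not from the steps above.

```java
import java.util.Arrays;

public class DqnTargetSketch {

    /** Hypothetical Q-network interface, assumed for illustration. */
    interface QNetwork {
        double[] predict(double[] state);            // feedforward pass: one Q-value per action
        void train(double[] state, double[] target); // one backpropagation step toward target
    }

    static void qLearningStep(QNetwork net,
                              double[] state, int action, double reward,
                              double[] nextState, boolean terminal,
                              double gamma) {
        // Step 1: feedforward pass for the current state s to get
        // predicted Q-values for all actions.
        double[] qValues = net.predict(state);

        // Step 2: feedforward pass for the next state s' and take the
        // maximum over all network outputs, max_a' Q(s', a').
        // (If s' is terminal, the target is just the reward, as in the paper.)
        double maxNextQ = 0.0;
        if (!terminal) {
            maxNextQ = Arrays.stream(net.predict(nextState)).max().getAsDouble();
        }

        // Step 3: the target vector equals the current predictions for
        // every action we did not take, so their error is 0...
        double[] target = qValues.clone();
        // ...except for the action actually taken, whose target is
        // r + gamma * max_a' Q(s', a').
        target[action] = reward + gamma * maxNextQ;

        // Step 4: update the weights with one backpropagation pass
        // against the target vector.
        net.train(state, target);
    }
}
```

Because the target vector equals the prediction for every action except the one taken, backpropagation only pushes gradient through the single output you actually want to move. That also resolves the double-vs-matrix mismatch from the question: you never subtract a scalar y_j from the whole output vector, only from the one entry Q(φ_j, a_j; θ).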