tensorflow, deep-learning, tensorflow2.0, reinforcement-learning

Using tf.gather on argmax gives a different result than taking the max


I'm trying to train a double-DQN agent in TensorFlow and it doesn't work. To make sure everything is fine, I wanted to test something: that using tf.gather on the argmax gives exactly the same result as taking the max. Let's say I have a network called target_network:

First, let's take the max:

next_qvalues_target1 = target_network.get_symbolic_qvalues(next_obs_ph)  # returns a tensor of Q-values, shape (n, 4)
next_state_values_target1 = tf.reduce_max(next_qvalues_target1, axis=1)

Now let's try it a different way, using argmax and gather:

next_qvalues_target2 = target_network.get_symbolic_qvalues(next_obs_ph)  # returns the same tensor of Q-values
chosen_action = tf.argmax(next_qvalues_target2, axis=1)
next_state_values_target2 = tf.gather(next_qvalues_target2, chosen_action)

diff = tf.reduce_sum(next_state_values_target1) - tf.reduce_sum(next_state_values_target2)
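To make the comparison concrete, here is the same pair of computations in NumPy with a small, made-up Q-table (3 next states by 4 actions; the numbers are hypothetical). np.take(..., axis=0) follows the same indexing rule as tf.gather with its default axis=0, so printing the shapes shows where the two paths part ways. This is an illustrative analogue, not the original TensorFlow graph.

```python
import numpy as np

# Hypothetical Q-values: 3 next states (rows) x 4 actions (columns).
next_qvalues = np.array([[1.0, 5.0, 2.0, 0.0],
                         [3.0, 1.0, 4.0, 2.0],
                         [0.9, 0.2, 0.1, 0.5]])

# Path 1: max over the action axis -> one value per state, shape (3,)
values1 = next_qvalues.max(axis=1)

# Path 2: argmax over actions, then gather along axis 0 (tf.gather's default axis)
chosen_action = next_qvalues.argmax(axis=1)             # [1, 2, 0], shape (3,)
values2 = np.take(next_qvalues, chosen_action, axis=0)  # whole rows 1, 2, 0 -> shape (3, 4)

# Summing a (3,) vector and a (3, 4) matrix gives different totals
diff = values1.sum() - values2.sum()
print(values1.shape, values2.shape, diff)
```

In this toy table every argmax happens to be a valid row index (0 to 2). With real data an action index can exceed the batch size; according to the tf.gather documentation, an out-of-range index raises an error on CPU, while on GPU the corresponding output is silently filled with 0.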

next_state_values_target2 and next_state_values_target1 are supposed to be completely identical, so running the session should output diff = 0. But it does not.

What am I missing?

Thanks.


Solution

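  • Found out what went wrong. tf.argmax(next_qvalues_target2, axis=1) returns chosen_action with shape (n,), one action index per row, so I assumed that calling tf.gather on next_qvalues_target2, which has shape (n, 4), would give one value per row. That isn't true: tf.gather indexes along axis 0 by default, so it takes complete rows and returns a tensor of shape (n, 4), and reduce_sum then adds up every Q-value in those rows. The fix is to turn chosen_action into index pairs of shape (n, 2): instead of [action1, action2, action3, ...] it needs to be [[0, action1], [1, action2], [2, action3], ...] (row indices start at 0), and to use tf.gather_nd instead of tf.gather so that it picks individual elements of next_qvalues_target2 rather than whole rows.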
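A minimal sketch of that fix, assuming TensorFlow 2.x with eager execution and a small hypothetical Q-table standing in for next_qvalues_target2 (in a TF1-style graph the same ops apply, you just evaluate them in the session). It builds the [[0, a0], [1, a1], ...] index pairs for tf.gather_nd, and also shows tf.gather with batch_dims=1, which selects next_qvalues[i, chosen_action[i]] without building the pairs by hand.

```python
import tensorflow as tf

# Hypothetical Q-values: 3 next states x 4 actions (stand-in for next_qvalues_target2).
next_qvalues = tf.constant([[1.0, 5.0, 2.0, 0.0],
                            [3.0, 1.0, 4.0, 2.0],
                            [0.9, 0.2, 0.1, 0.5]])

# Reference result: max over the action axis, shape (3,)
max_values = tf.reduce_max(next_qvalues, axis=1)

# Greedy action per row; int32 so it can be stacked with tf.range below
chosen_action = tf.argmax(next_qvalues, axis=1, output_type=tf.int32)  # shape (3,)

# Fix 1: pair each row index (starting at 0) with its action -> [[0, a0], [1, a1], [2, a2]]
row_idx = tf.range(tf.shape(next_qvalues)[0])         # [0, 1, 2]
indices = tf.stack([row_idx, chosen_action], axis=1)  # shape (3, 2)
values_nd = tf.gather_nd(next_qvalues, indices)       # one element per row, shape (3,)

# Fix 2: tf.gather with batch_dims=1 picks next_qvalues[i, chosen_action[i]] directly
values_bd = tf.gather(next_qvalues, chosen_action, batch_dims=1)  # shape (3,)

diff = tf.reduce_sum(max_values) - tf.reduce_sum(values_nd)
print(values_nd.numpy(), values_bd.numpy(), diff.numpy())
```

Both variants return one Q-value per state, matching tf.reduce_max. The batch_dims form is shorter; the gather_nd form makes the (row, action) pairing explicit, which can be easier to debug.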