python, tensorflow, reinforcement-learning

Scatter-update a tensor with an index obtained from argmax


I'm trying to replace the maximum value in each row of a tensor with another value, like so:

    actions = tf.argmax(output, axis=1)
    gen_targets = tf.scatter_nd_update(output, actions, q_value)

The scatter_nd_update call fails with: AttributeError: 'Tensor' object has no attribute 'handle'.

The output and reward placeholders are declared as:

    output = tf.placeholder('float', shape=[None, num_action])
    reward = tf.placeholder('float', shape=[None])

What am I doing wrong and what would be the correct way to achieve this?


Solution

  • You are trying to update output, which is a tf.placeholder. Placeholders are immutable: you cannot change the value of a placeholder in place. tf.scatter_nd_update() only works on a variable, e.g. a tf.Variable, which is why passing a plain tensor fails with the 'handle' AttributeError. One way to solve this is to create a variable and assign the placeholder's value to it with tf.assign(). Since the first dimension of the placeholder is None and can have any size at runtime, set the validate_shape argument of tf.assign() to False; this way the shape of the placeholder does not need to match the shape of the variable. After the assignment, the shape of var_output matches the actual shape of the value fed through the placeholder.

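    import tensorflow as tf  # TensorFlow 1.x (graph mode) API

    output = tf.placeholder('float', shape=[None, num_action])
    actions = tf.argmax(output, axis=1)  # column index of each row's max
    # dummy variable initialization; its shape is replaced by the assign below
    var_output = tf.Variable(0, dtype=output.dtype)
    
    # assign value of placeholder to the var_output
    # (validate_shape=False lets the shape differ from the dummy's)
    var_output = tf.assign(var_output, output, validate_shape=False)
    # ...
    # scatter_nd_update needs [row, column] index pairs, shape [batch, 2];
    # q_value must hold one new value per row, shape [batch]
    rows = tf.range(tf.shape(actions)[0])
    indices = tf.stack([rows, tf.cast(actions, tf.int32)], axis=1)
    gen_targets = tf.scatter_nd_update(var_output, indices, q_value)
    # ...
    sess.run(gen_targets, feed_dict={output: feed_your_placeholder_here})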
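To make the intended result concrete, here is a small NumPy sketch of what the update should compute (NumPy stands in for the TensorFlow ops; the function names and sample numbers are hypothetical, chosen only for illustration). tf.argmax(output, axis=1) yields one column index per row, so each update needs a full [row, column] pair, which is the [batch, 2] index shape tf.scatter_nd_update expects. The second function gets the same result without any in-place update by masking with a one-hot matrix; that alternative is not part of the original answer.

```python
import numpy as np

def replace_row_max(output, q_value):
    """Copy of `output` with each row's argmax entry replaced by q_value[row].
    Mirrors the scatter-style update with [row, column] index pairs."""
    actions = np.argmax(output, axis=1)          # one column index per row
    rows = np.arange(output.shape[0])            # matching row index for each action
    indices = np.stack([rows, actions], axis=1)  # shape [batch, 2], like scatter_nd indices
    targets = output.copy()                      # leave the input untouched
    targets[indices[:, 0], indices[:, 1]] = q_value
    return targets

def replace_row_max_one_hot(output, q_value):
    """Same result via a one-hot mask: keep all entries except the argmax,
    which is swapped for q_value[row]. No in-place update needed."""
    mask = np.eye(output.shape[1])[np.argmax(output, axis=1)]  # one-hot of each argmax
    return output * (1.0 - mask) + mask * q_value[:, None]

output = np.array([[0.1, 0.7, 0.2],
                   [0.5, 0.3, 0.9]])
q_value = np.array([1.5, -2.0])
print(replace_row_max(output, q_value))
print(replace_row_max_one_hot(output, q_value))
```

In graph-mode TensorFlow the masking version translates to tf.one_hot(actions, num_action) plus elementwise arithmetic, which produces a new tensor and so sidesteps the placeholder/variable issue entirely; in TensorFlow 2.x, tf.tensor_scatter_nd_update(output, indices, q_value) performs the scatter-style update on a plain tensor.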