It seems that a tf.function is not able to return a tf.Variable. For example, the code below fails on the second call to run_episode.
import tensorflow as tf

max_mem_size = 5
action = 3

@tf.function
def run_episode(x, actions):
    actions = actions.scatter_nd_update([[x]], [[action]])
    return actions

actions = tf.Variable(tf.zeros((max_mem_size, 1), dtype=tf.int32))
actions = run_episode(1, actions)
actions = run_episode(2, actions)  # fails here
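A likely reason the second call fails: outputs of a tf.function are converted to ordinary tensors, so after the first call actions is a plain tensor with no scatter_nd_update method. A minimal sketch of a workaround, assuming the move log only needs to persist between calls: create the variable once outside the function, let the function update it in place, and return nothing. The values max_mem_size = 5 and action = 3 are assumed illustration values.

```python
import tensorflow as tf

# Assumed illustration values.
max_mem_size = 5
action = 3

# Create the variable once, outside the tf.function, so it persists across calls.
actions = tf.Variable(tf.zeros((max_mem_size, 1), dtype=tf.int32))

@tf.function
def run_episode(x):
    # Update the captured variable in place; nothing needs to be returned.
    actions.scatter_nd_update([[x]], [[action]])

run_episode(1)
run_episode(2)
print(actions.numpy().ravel())  # [0 3 3 0 0]
```

Because the variable itself is never passed back out of the function, actions stays a tf.Variable and can be updated on every call.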
I need to keep a record of all the moves taken in a game. I've tried TensorArray and now Variable, but TensorFlow seems unable to manage a memory variable within a graph function. Does anyone know how I can make run_episode return a tf.Variable?
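If the whole game can run inside a single tf.function, a TensorArray does work: moves are written inside the loop and stacked once at the end, so nothing mutable has to outlive a single call. A minimal sketch; play_game and its move rule (t % 3) are hypothetical stand-ins for a real game loop and policy.

```python
import tensorflow as tf

@tf.function
def play_game(num_moves):
    # Growable buffer for the moves of this game.
    moves = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
    # AutoGraph turns this loop over tf.range into a tf.while_loop.
    for t in tf.range(num_moves):
        move = t % 3  # hypothetical policy: stand-in for choosing an action
        moves = moves.write(t, move)
    # Convert the TensorArray to a regular 1-D tensor before returning.
    return moves.stack()

print(play_game(tf.constant(5)).numpy())  # [0 1 2 0 1]
```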
How about using tf.stack and tf.unstack?
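import tensorflow as tf

max_mem_size = 5
action = 3

@tf.function
def run_episode(x, actions):
    # Split the 1-D tensor into a Python list of scalars, replace entry x,
    # then reassemble. x must be a Python int so it can index the list.
    a_list = tf.unstack(actions)
    a_list[x] = action
    return tf.stack(a_list)

# A 1-D int32 vector of length max_mem_size. (tf.zeros(max_mem_size, 1)
# would pass 1 as the dtype argument, not as a second dimension.)
actions = tf.zeros(max_mem_size, dtype=tf.int32)
actions = run_episode(1, actions)
actions = run_episode(2, actions)
print(actions.numpy())  # [0 3 3 0 0]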
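Another option is tf.tensor_scatter_nd_update, which returns a new tensor with the selected entries replaced instead of unstacking and restacking. It also accepts a tensor index, so passing x as a tf.constant should avoid retracing run_episode for every new position (each distinct Python int value of x triggers a new trace). A minimal sketch, reusing max_mem_size = 5 and action = 3:

```python
import tensorflow as tf

max_mem_size = 5
action = 3

@tf.function
def run_episode(x, actions):
    # Return a new tensor with position x set to `action`; the input is unchanged.
    return tf.tensor_scatter_nd_update(actions, [[x]], [action])

actions = tf.zeros(max_mem_size, dtype=tf.int32)
# Tensor indices share one trace instead of one trace per Python int.
actions = run_episode(tf.constant(1), actions)
actions = run_episode(tf.constant(2), actions)
print(actions.numpy())  # [0 3 3 0 0]
```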