
tf.function converts a Variable to a Tensor automatically

It seems that a tf.function is not able to return a tf.Variable. For example, the code below fails on the second call to run_episode.

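import tensorflow as tf

max_mem_size = 5
action = 3

@tf.function
def run_episode(x, actions):
    # Write `action` into row x of the Variable
    actions = actions.scatter_nd_update([[x]], [[action]])
    return actions

actions = tf.Variable(tf.zeros((max_mem_size, 1), dtype=tf.int32))

actions = run_episode(1, actions)  # comes back as a Tensor, not a tf.Variable
actions = run_episode(2, actions)  # fails: a Tensor has no scatter_nd_update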
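A common workaround, sketched here as one possible approach rather than the only fix, is not to return the Variable at all: create it once outside the tf.function and let the function mutate it in place. Passing action as a second argument, and the values max_mem_size = 5 and action = 3, are illustrative assumptions.

```python
import tensorflow as tf

max_mem_size = 5  # illustrative size (assumption)

# Create the Variable once, outside the tf.function
actions = tf.Variable(tf.zeros((max_mem_size, 1), dtype=tf.int32))

@tf.function
def run_episode(x, action):
    # Mutate the captured Variable in place; tf.function's automatic
    # control dependencies make sure this stateful op runs even though
    # nothing is returned.
    actions.scatter_nd_update([[x]], [[action]])

run_episode(1, 3)
run_episode(2, 4)
print(actions.numpy().ravel())  # [0 3 4 0 0]
```

Because the Python name actions is never rebound to a function output, it stays a tf.Variable across calls.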

I need to keep a record of every move taken in a game. I've tried tf.TensorArray and now tf.Variable, but TensorFlow seems unable to keep that state across calls to a graph function.

Does anyone know how I can make run_episode return a tf.Variable?
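If the whole episode can run inside a single tf.function call, a tf.TensorArray can serve as the move memory, provided the result of each write() is reassigned and the array is stack()ed into an ordinary Tensor before returning. A minimal sketch; play_episode and the "step * 2" policy are hypothetical stand-ins for the real game loop.

```python
import tensorflow as tf

@tf.function
def play_episode(num_steps):
    # Growable array holding one int32 move per step
    moves = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
    for step in tf.range(num_steps):  # AutoGraph turns this into a tf.while_loop
        action = step * 2  # hypothetical policy, stands in for the real move choice
        moves = moves.write(step, action)  # write() returns a new TensorArray
    # Convert to a regular Tensor so it can be returned from the graph
    return moves.stack()

history = play_episode(tf.constant(4))
print(history.numpy())  # [0 2 4 6]
```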


  • How about using tf.stack and tf.unstack?

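    import tensorflow as tf

    max_mem_size = 5
    action = 3

    @tf.function
    def run_episode(x, actions):
        # Split the 1-D tensor into a list of scalar tensors,
        # overwrite position x, then stack them into a new tensor
        a_list = tf.unstack(actions)
        a_list[x] = tf.cast(action, actions.dtype)
        return tf.stack(a_list)

    # tf.zeros takes the shape first and the dtype second, so pass dtype by keyword
    actions = tf.zeros(max_mem_size, dtype=tf.int32)
    actions = run_episode(1, actions)
    actions = run_episode(2, actions)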
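A shorter tensor-only alternative to unstack/stack is tf.tensor_scatter_nd_update, which returns a new tensor with the given positions replaced and leaves its input unchanged, so actions still has to be reassigned after each call. A sketch reusing max_mem_size = 5 and action = 3 with a 1-D actions tensor:

```python
import tensorflow as tf

max_mem_size = 5
action = 3

@tf.function
def run_episode(x, actions):
    # indices [[x]] address one element of the 1-D tensor;
    # updates therefore has shape [1]
    return tf.tensor_scatter_nd_update(actions, [[x]], [action])

actions = tf.zeros(max_mem_size, dtype=tf.int32)
actions = run_episode(1, actions)
actions = run_episode(2, actions)
print(actions.numpy())  # [0 3 3 0 0]
```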