Tags: python, tensorflow, debugging, variable-assignment, tensorflow1.15

Error assigning variable values in a TensorFlow model


I have a TensorFlow model that I have loaded from a repository as

model = tf.saved_model.load(folder)

My objective is to replicate this same model in Jax. To do so, I need to check whether the loaded variable values (weights and biases) are the correct ones.

One way I can recover the value of variable i is just

vars = model.variables
print(vars[i].numpy())
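Before debugging individual layers, it can help to check systematically that every array copied into the Jax network matches what model.variables returned. Below is a minimal NumPy-only sketch. It assumes both lists hold the same parameters in the same order. find_mismatches is a hypothetical helper, and its tolerances are arbitrary choices.

```python
import numpy as np


def find_mismatches(reference, candidate, rtol=1e-5, atol=1e-6):
    """Compare two equally ordered lists of arrays.

    Returns a list of (index, reason) tuples, one per pair that differs
    in shape or in value (beyond the given tolerances).
    """
    if len(reference) != len(candidate):
        raise ValueError(f"length mismatch: {len(reference)} vs {len(candidate)}")
    mismatches = []
    for i, (ref, cand) in enumerate(zip(reference, candidate)):
        ref = np.asarray(ref)
        cand = np.asarray(cand)
        if ref.shape != cand.shape:
            # e.g. a transposed (non-square) kernel shows up here
            mismatches.append((i, f"shape {ref.shape} vs {cand.shape}"))
        elif not np.allclose(ref, cand, rtol=rtol, atol=atol):
            max_diff = float(np.max(np.abs(ref - cand)))
            mismatches.append((i, f"max abs diff {max_diff:.3g}"))
    return mismatches


# Usage with toy arrays (stand-ins for the TF values and the Jax parameters)
ref = [np.ones((2, 3)), np.zeros(3)]
cand = [np.ones((2, 3)), np.full(3, 0.5)]
print(find_mismatches(ref, cand))  # [(1, 'max abs diff 0.5')]
```

The ordering assumption matters. jax.tree_util.tree_leaves flattens a dict of parameters in sorted-key order, and that order need not match model.variables. Pairing arrays by name is safer when names are available.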

If I load these values into the Jax network, however, I do not get the right results. To debug this, I am analyzing the outputs of specific layers, which requires the weights and biases to be identical in both models, e.g. by assigning them explicitly beforehand. Specifically, if I do

numpy_vars = [v.numpy() for v in vars]  # This is done in eager mode.

with tf.compat.v1.Session(graph=graph) as sess:
    tvars = tf.compat.v1.trainable_variables()
    tf.compat.v1.variables_initializer(vars).run()  # Initializing either tvars or vars is necessary
    for v, tv in zip(numpy_vars, tvars):
        tv.assign(v)
    print(tvars[0].eval())  # This returns the value of the variable in graph mode.
    print('------------------------------')
    print(numpy_vars[0])

The two printed values are not the same, although I expected them to be, and both arrays have the same shape. I wonder whether this is caused by initialization operations in model.graph, but I am not sure. If I instead replace the line

tv.assign(v)

with

sess.run(tv.assign(v))

I get the error

TypeError: Argument `fetch` = <tf.Variable 'UnreadVariable' shape=(11, 256) dtype=float32> has invalid type "_UnreadVariable" must be a string or Tensor. (Can not convert a _UnreadVariable into a Tensor or Operation.)

Any suggestions on how to assign the values of these variables so that they stay fixed during graph execution?


Solution

  • The answer seems to be this:

    numpy_vars = [v.numpy() for v in vars]  # eager-mode copies of the loaded values

    with tf.compat.v1.Session(graph=graph) as sess:
        tvars = tf.compat.v1.trainable_variables()
        tf.compat.v1.variables_initializer(vars).run()
        print(tvars[0].eval())  # value before the assignment
        print('------------------------------')
        for v, tv in zip(numpy_vars, tvars):
            # Evaluating read_value() forces the assign op to run in this session
            tf.compat.v1.assign(tv, v).read_value().eval()
        print(tvars[0].eval())  # value after the assignment; should equal numpy_vars[0]
        print('------------------------------')
        print(numpy_vars[0])
    

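    With the line

    tf.compat.v1.assign(tv, v).read_value().eval()

    in place, I checked that the weights and biases are assigned correctly. In graph mode, tv.assign(v) only adds an assign op to the graph, and nothing changes until that op is run. Evaluating the tensor returned by read_value() runs the assignment first, because the read depends on it.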
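Another route is to pass read_value=False to assign. A minimal sketch of that approach follows, assuming TF 2.x with the tf.compat.v1 API. With that flag, assign returns a plain assign Operation, which sess.run can fetch, unlike the _UnreadVariable in the question's error message. The graph, the variable "w", and the helpers assign_numpy_values and demo are hypothetical stand-ins for the loaded model.

```python
import numpy as np
import tensorflow as tf


def assign_numpy_values(sess, variables, values):
    """Write each NumPy array into the matching graph variable, running all writes once."""
    with sess.graph.as_default():
        # read_value=False makes assign() return an Operation instead of an
        # _UnreadVariable, so sess.run can fetch it.
        assign_ops = [var.assign(val, read_value=False)
                      for var, val in zip(variables, values)]
        sess.run(tf.group(*assign_ops))


def demo():
    graph = tf.Graph()
    with graph.as_default():
        # Hypothetical stand-in for one of the model's trainable variables.
        w = tf.compat.v1.get_variable(
            "w", shape=(2, 3), dtype=tf.float32,
            initializer=tf.compat.v1.zeros_initializer(), use_resource=True)
        init_op = tf.compat.v1.variables_initializer([w])

    new_value = np.arange(6, dtype=np.float32).reshape(2, 3)
    with tf.compat.v1.Session(graph=graph) as sess:
        sess.run(init_op)
        w.assign(new_value)   # only builds an op in the graph; it is never run
        before = sess.run(w)  # still the initial zeros
        assign_numpy_values(sess, [w], [new_value])
        after = sess.run(w)   # now holds new_value
    return before, after, new_value


before, after, new_value = demo()
print(before)
print(after)
```

In the question's setting, this would be called inside the Session block as assign_numpy_values(sess, tvars, numpy_vars). That call relies on the same assumption as the original zip: trainable_variables() and model.variables list the same variables in the same order. Assigning NumPy arrays this way embeds them as constants in the graph. For very large weights, feeding them through placeholders avoids that, at the cost of a little more code.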