python, numpy, tensorflow, global-variables

How to initialize a tf.Variable with a tf.constant or a numpy array?


I am trying to initialize a tf.Variable() in a tf.InteractiveSession(). I already have pre-trained weights stored as individual numpy files. How do I initialize the variable with these numpy values?

I have gone through the following options:

  1. Using tf.assign()
  2. Using sess.run() directly during tf.Variable() creation

The values do not seem to be initialized correctly. Below is the code I have tried. Which approach is the correct one?

import numpy as np
import tensorflow as tf

def read_numpy(file):
    return np.fromfile(file,dtype='f')

def build_network():
    with tf.get_default_graph().as_default():
        x = tf.Variable(tf.constant(read_numpy('foo.npy')),name='var1')
        sess = tf.get_default_session()
        with sess.as_default():
            sess.run(tf.global_variables_initializer())

sess = tf.InteractiveSession()
with sess.as_default():
    build_network()

Is this the correct way to do it? I have printed the session object, and the same session is used throughout.

Edit: Currently it seems that sess.run(tf.global_variables_initializer()) is calling a random initialization op.


Solution

  • tf.Variable() accepts numpy arrays as initial values:

    import tensorflow as tf
    import numpy as np
    
    init = np.ones((2, 2))
    x = tf.Variable(init) # <-- the numpy array becomes the variable's initial value
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer()) # <-- this will assign the init value
        print(x.eval())
    # [[1. 1.]
    #  [1. 1.]]
    

    So just pass the numpy array as the initial value; there is no need to convert it to a tensor first.

    Alternatively, you can use tf.Variable.load() to assign values from a numpy array to a variable within a session context:

    import tensorflow as tf
    import numpy as np
    
    x = tf.Variable(tf.zeros((2, 2)))
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        init = np.ones((2, 2))
        x.load(init)
        print(x.eval())
    # [[1. 1.]
    #  [1. 1.]]
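
A side note on the question's read_numpy(): np.fromfile() reads raw bytes and does not parse the .npy header. If foo.npy was written with np.save() (an assumption, since the question does not say how the files were created), the array it returns would be flat and would contain extra meaningless values from the header, which could look like a random initialization. np.load() is the matching reader for that format. A minimal sketch using a hypothetical temporary file and made-up weights:

```python
import os
import tempfile

import numpy as np

# Hypothetical stand-in for the pre-trained weights.
weights = np.arange(6, dtype=np.float32).reshape(2, 3)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "foo.npy")
    np.save(path, weights)  # writes a .npy header followed by the raw data

    loaded = np.load(path)               # parses the header, restores dtype and shape
    raw = np.fromfile(path, dtype="f")   # treats the header bytes as float32 values too

print(loaded.shape, loaded.dtype)  # shape and dtype match the saved array
print(raw.shape)                   # flat, and longer than the 6 saved values
```

If the files are raw binary dumps instead (for example written with ndarray.tofile()), np.fromfile() is the right reader, but the result still needs a reshape() to the variable's shape before it is passed to tf.Variable().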