Tags: python-2.7, tensorflow, checkpointing

Accessing values from a restored TensorFlow variable


I have a simple recurrent network example, with a tf.train.Saver saving the weight, bias and state variables.

When the example is run with no options, it initialises the state vector to zeros, but I want to be able to pass a load_model option and have it use the last values of the state vector as the feed for the session.run invocation.

All the documentation I have seen insists that one must invoke session.run to retrieve stored values from variables, but in this case I want to retrieve the values precisely so that I can initialise the state variable. Do I need to build a separate graph just to retrieve the initialisation values?

Example code below:

import tensorflow as tf
import math
import numpy as np

INPUTS = 10
HIDDEN_1 = 2
BATCH_SIZE = 3

def batch_vm2(m, x):
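  # Batched matrix multiply: multiplies x (with arbitrary leading batch
  # dimensions) by the matrix m, preserving the batch shape in the result.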
  [input_size, output_size] = m.get_shape().as_list()

  input_shape = tf.shape(x)
  batch_rank = input_shape.get_shape()[0].value - 1
  batch_shape = input_shape[:batch_rank]
  output_shape = tf.concat(0, [batch_shape, [output_size]])

  x = tf.reshape(x, [-1, input_size])
  y = tf.matmul(x, m)

  y = tf.reshape(y, output_shape)

  return y

def get_weight_and_biases():
    with tf.variable_scope(network_scope, reuse = True) as scope:
        weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0 / math.sqrt(float(INPUTS))))
        biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
    return weights, biases

def get_saver():
    with tf.variable_scope('h1') as scope:
        weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0 / math.sqrt(float(INPUTS))))
        biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
        state = tf.get_variable('state', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0), trainable=False)
        saver = tf.train.Saver([weights, biases, state])
    return saver, scope


def load(sess, saver, checkpoint_dir = './'):

        print("loading a session")
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            raise Exception("no checkpoint found")
        return

iteration = None

def iterate_state(prev_state_tuple, input):
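    # Step function for tf.scan: reuses the 'h1' variables, assigns the
    # persistent 'state' variable to 0.99 * previous_state + 0.01 * (x.W + b),
    # and returns the new state concatenated with its tanh output.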
    with tf.variable_scope(network_scope, reuse = True) as scope:
        weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0 / math.sqrt(float(INPUTS))))
        biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
        state = tf.get_variable('state', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0), trainable=False)
        print("input: ",input.get_shape())
        matmuladd = batch_vm2(weights, input) + biases
        matmulpri = tf.Print(matmuladd,[matmuladd, weights], message=" malmul -> %i, weights " % iteration)
        print("prev state: ",prev_state_tuple.get_shape())
        unpacked_state, unpacked_out = tf.split(0,2,prev_state_tuple)
        prev_state = 0.99* unpacked_state
        prev_state = tf.Print(prev_state, [unpacked_state, matmuladd], message=" -> prevstate, matmulpri ")
        state = state.assign( prev_state + 0.01*matmulpri )
        #output = tf.nn.relu(state)
        output = tf.nn.tanh(state)
        state = tf.Print(state, [state], message=" state -> ")
        output = tf.Print(output, [output], message=" output -> ")
        print(" state: ", state.get_shape())
        print(" output: ", output.get_shape())
        concat_result = tf.concat(0,[state, output])
        print (" concat return: ", concat_result.get_shape())
        return concat_result

def data_iter():
    while True:
        idxs = np.random.rand(BATCH_SIZE, INPUTS)
        yield idxs

flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_boolean('load_model', False, 'If true, uses model files '
                     'to restore.')


network_scope = None

with tf.Graph().as_default():
    inputs = tf.placeholder(tf.float32, shape=(BATCH_SIZE, INPUTS))
    iteration = -1
    saver, network_scope = get_saver()
    initial_state = tf.placeholder(tf.float32, shape=(HIDDEN_1))
    initial_out = tf.zeros([HIDDEN_1],
                             name='initial_out')
    concat_tensor = tf.concat(0,[initial_state, initial_out])
    print(" init state: ",initial_state.get_shape())
    print(" init out: ",initial_out.get_shape())
    print(" concat: ",concat_tensor.get_shape())
    scanout = tf.scan(iterate_state, inputs, initializer=concat_tensor, name='state_scan')
    print ("scanout shape: ", scanout.get_shape())
    state, output = tf.split(1,2,scanout, name='split_scan_output')
    print(" end state: ",state.get_shape())
    print(" end out: ",output.get_shape())


    sess = tf.Session()
    # Run the Op to initialize the variables.

    sess.run(tf.initialize_all_variables())
    tf.train.write_graph(sess.graph_def, './tenIrisSave/logsd','graph.pbtxt')
    tf_weight, tf_bias = get_weight_and_biases()
    tf.histogram_summary('weights', tf_weight)
    tf.histogram_summary('bias', tf_bias)
    tf.histogram_summary('state', state)
    tf.histogram_summary('out', output)
    summary_op = tf.merge_all_summaries()
    summary_writer = tf.train.SummaryWriter('./tenIrisSave/summary',sess.graph_def)
    if FLAGS.load_model:
        load(sess, saver)
        # HOW DO I LOAD restored state values??????
        #st = state[BATCH_SIZE - 1,:]
        #st = sess.run([state], feed_dict={})
        print("LOADED last state vec: ", st)
    else:
        st = np.array([0.0 , 0.0])
    iter_ = data_iter()
    for i in xrange(0, 1):
        print ("iteration: ",i)
        iteration = i
        input_data = iter_.next()
        out,st,so,summary_str = sess.run([output,state,scanout,summary_op], feed_dict={ inputs: input_data, initial_state: st })
        saver.save(sess, 'my-model', global_step=1+i)
        summary_writer.add_summary(summary_str, i)
        summary_writer.flush()
        print("input vec: ", input_data)
        print("state vec: ", st)
        st = st[-1]
        print("last state vec: ", st)
        print("output vec: ", out)
        print(" end state (runtime): ",st.shape)
        print(" end out (runtime): ",out.shape)
        print(" end scanout (runtime): ",so.shape)

Note the commented-out lines in the FLAGS.load_model branch near the end of the script: those show the ways I tried to obtain a value of st for the feed dictionary. None of them work.


Solution

  • You have two placeholders:

    • inputs
    • initial_state

    From what I understand, depending on FLAGS.load_model, you want to either:

    1. Use an initial state full of zeros

      • this is simple: you just feed a numpy array full of zeros (see the sketch after this list)
    2. Use the last row of state, which is a Tensor in the graph that depends on both placeholders.

      • you just want to load the value from a previous checkpoint
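
    For case 1, a minimal sketch reusing the names from your own script is just:

    st = np.zeros(HIDDEN_1, dtype=np.float32)  # zero initial state, shape [HIDDEN_1]
    # later: sess.run(..., feed_dict={inputs: input_data, initial_state: st})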

    With this analysis done, my first hypothesis is that the problem simply comes from the fact that you reuse the name state for another tensor in the line:

    state, output = tf.split(1,2,scanout, name='split_scan_output')
    

    So when you ask for state, TensorFlow tries to evaluate this tensor, which depends on both placeholders, instead of retrieving the value of the Variable state you want. Just rename the second one and it should work.
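
    For example (state_seq and output_seq are just illustrative names; anything other than state will do):

    state_seq, output_seq = tf.split(1, 2, scanout, name='split_scan_output')

    and then use state_seq / output_seq in the histogram summaries and in the sess.run fetches further down.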

    You can try:

    if FLAGS.load_model:
        load(sess, saver)
        with tf.variable_scope('h1', reuse=True):
            state_saved = tf.get_variable('state')
        st = sess.run(state_saved)
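
    Here st is the restored numpy value of the h1/state Variable (shape [HIDDEN_1]), so you can feed it to initial_state exactly like the zero array in the other branch, e.g. (assuming the rename sketched above):

    sess.run([output_seq, state_seq, scanout, summary_op],
             feed_dict={inputs: input_data, initial_state: st})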