I have a simple recurrent network example, with a `tf.Saver` and `weight`, `bias` and `state` variables being saved. When the example is run with no options, it initialises the state vector to zeros, but I want to pass a `load_model` option and have it use the last saved values of the state vector as a feed for the `session.run` invocation.
All the documentation I have seen insists that one must invoke `session.run` to retrieve stored values from variables, but in this case I want to retrieve the values precisely so that I can initialise the state variable. Do I need a separate graph just to retrieve the initialisation values?
Example code below:
import tensorflow as tf
import math
import numpy as np

INPUTS = 10
HIDDEN_1 = 2
BATCH_SIZE = 3
def batch_vm2(m, x):
    # Multiply a (possibly batched) tensor x by matrix m, preserving
    # the leading batch dimensions of x.
    [input_size, output_size] = m.get_shape().as_list()
    input_shape = tf.shape(x)
    batch_rank = input_shape.get_shape()[0].value - 1
    batch_shape = input_shape[:batch_rank]
    output_shape = tf.concat(0, [batch_shape, [output_size]])
    x = tf.reshape(x, [-1, input_size])
    y = tf.matmul(x, m)
    y = tf.reshape(y, output_shape)
    return y
def get_weight_and_biases():
    with tf.variable_scope(network_scope, reuse=True) as scope:
        weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0 / math.sqrt(float(INPUTS))))
        biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
    return weights, biases
def get_saver():
    with tf.variable_scope('h1') as scope:
        weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0 / math.sqrt(float(INPUTS))))
        biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
        state = tf.get_variable('state', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0), trainable=False)
    saver = tf.train.Saver([weights, biases, state])
    return saver, scope
def load(sess, saver, checkpoint_dir='./'):
    print("loading a session")
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        raise Exception("no checkpoint found")
    return
iteration = None
def iterate_state(prev_state_tuple, input):
    with tf.variable_scope(network_scope, reuse=True) as scope:
        weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0 / math.sqrt(float(INPUTS))))
        biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
        state = tf.get_variable('state', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0), trainable=False)
        print("input: ", input.get_shape())
        # project the input through the layer
        matmuladd = batch_vm2(weights, input) + biases
        matmulpri = tf.Print(matmuladd, [matmuladd, weights], message=" matmul -> %i, weights " % iteration)
        print("prev state: ", prev_state_tuple.get_shape())
        unpacked_state, unpacked_out = tf.split(0, 2, prev_state_tuple)
        # leaky update: keep 99% of the previous state, mix in 1% of the new projection
        prev_state = 0.99 * unpacked_state
        prev_state = tf.Print(prev_state, [unpacked_state, matmuladd], message=" -> prevstate, matmulpri ")
        state = state.assign(prev_state + 0.01 * matmulpri)
        #output = tf.nn.relu(state)
        output = tf.nn.tanh(state)
        state = tf.Print(state, [state], message=" state -> ")
        output = tf.Print(output, [output], message=" output -> ")
        print(" state: ", state.get_shape())
        print(" output: ", output.get_shape())
        concat_result = tf.concat(0, [state, output])
        print(" concat return: ", concat_result.get_shape())
    return concat_result
def data_iter():
    while True:
        idxs = np.random.rand(BATCH_SIZE, INPUTS)
        yield idxs
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_boolean('load_model', False, 'If true, uses model files to restore.')

network_scope = None
with tf.Graph().as_default():
    inputs = tf.placeholder(tf.float32, shape=(BATCH_SIZE, INPUTS))
    iteration = -1
    saver, network_scope = get_saver()
    initial_state = tf.placeholder(tf.float32, shape=(HIDDEN_1))
    initial_out = tf.zeros([HIDDEN_1], name='initial_out')
    concat_tensor = tf.concat(0, [initial_state, initial_out])
    print(" init state: ", initial_state.get_shape())
    print(" init out: ", initial_out.get_shape())
    print(" concat: ", concat_tensor.get_shape())
    scanout = tf.scan(iterate_state, inputs, initializer=concat_tensor, name='state_scan')
    print("scanout shape: ", scanout.get_shape())
    state, output = tf.split(1, 2, scanout, name='split_scan_output')
    print(" end state: ", state.get_shape())
    print(" end out: ", output.get_shape())
    sess = tf.Session()
    # Run the Op to initialize the variables.
    sess.run(tf.initialize_all_variables())
    tf.train.write_graph(sess.graph_def, './tenIrisSave/logsd', 'graph.pbtxt')
    tf_weight, tf_bias = get_weight_and_biases()
    tf.histogram_summary('weights', tf_weight)
    tf.histogram_summary('bias', tf_bias)
    tf.histogram_summary('state', state)
    tf.histogram_summary('out', output)
    summary_op = tf.merge_all_summaries()
    summary_writer = tf.train.SummaryWriter('./tenIrisSave/summary', sess.graph_def)
    if FLAGS.load_model:
        load(sess, saver)
        # HOW DO I LOAD restored state values??????
        #st = state[BATCH_SIZE - 1,:]
        #st = sess.run([state], feed_dict={})
        print("LOADED last state vec: ", st)
    else:
        st = np.array([0.0, 0.0])
    iter_ = data_iter()
    for i in xrange(0, 1):
        print("iteration: ", i)
        iteration = i
        input_data = iter_.next()
        out, st, so, summary_str = sess.run([output, state, scanout, summary_op], feed_dict={inputs: input_data, initial_state: st})
        saver.save(sess, 'my-model', global_step=1 + i)
        summary_writer.add_summary(summary_str, i)
        summary_writer.flush()
        print("input vec: ", input_data)
        print("state vec: ", st)
        st = st[-1]
        print("last state vec: ", st)
        print("output vec: ", out)
        print(" end state (runtime): ", st.shape)
        print(" end out (runtime): ", out.shape)
        print(" end scanout (runtime): ", so.shape)
Note the commented-out lines in the `FLAGS.load_model` branch: those are the ways I have tried to initialise the feed-dictionary value. None of them work.
You have two placeholders:

- `inputs`
- `initial_state`

From what I understand, depending on `FLAGS.load_model` you want to either:

- use an initial state full of zeros, or
- use the last row of `state`, which is a Tensor in the graph that depends on both placeholders.

With this analysis done, my first hypothesis is that the error comes simply from the fact that you use another tensor named `state` in the line:

state, output = tf.split(1, 2, scanout, name='split_scan_output')

So TensorFlow will try to retrieve this `state`, which depends on both placeholders, instead of retrieving the value of the Variable `state` that you want. Just rename the second one and it might work.
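For example (`scan_state` and `scan_output` are just suggested names):

# rename the split results so they no longer shadow the 'state' Variable
scan_state, scan_output = tf.split(1, 2, scanout, name='split_scan_output')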
Then, to read the saved value back out of the Variable itself, you can try:
if FLAGS.load_model:
    load(sess, saver)
    with tf.variable_scope('h1', reuse=True):
        state_saved = tf.get_variable('state')
    st = sess.run(state_saved)
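`st` is then a plain numpy array of shape `[HIDDEN_1]`, so you can feed it straight back in as before, e.g. (a sketch using the renamed split outputs from above):

out, st, so, summary_str = sess.run([scan_output, scan_state, scanout, summary_op],
                                    feed_dict={inputs: input_data, initial_state: st})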