python tensorflow training-data

Basic TensorFlow Example - Prediction of a Line

I'm trying to build a super simple example with TensorFlow, and I clearly don't fully understand the TensorFlow API.

I have the following code. It's not originally mine: I found it in some demo, but I can't remember where, or I would give the author credit. Apologies.

Saving the Trained Line Model

import tensorflow as tf
import numpy as np

# Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3

# Try to find values for W and b that compute y_data = W * x_data + b
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='W')
b = tf.Variable(tf.zeros([1]), name='b')
y = W * x_data + b

# Minimize the mean squared errors.
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

# Before starting, initialize the variables.  We will 'run' this first.
init = tf.global_variables_initializer()

# Create a Saver to checkpoint the variables
saver = tf.train.Saver()

# Launch the graph.
sess = tf.Session() 

sess.run(init)

# Fit the line.
for step in range(201):
    sess.run(train)
    if step % 20 == 0:
        print(step, sess.run(W), sess.run(b))
        saver.save(sess, 'linemodel')
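As a rough cross-check of what this graph does: optimizer.minimize(loss) with GradientDescentOptimizer(0.5) performs full-batch gradient descent on the mean squared error. The NumPy sketch below does the same update by hand. It is only an illustration under stated assumptions (float64 instead of float32, a seeded generator, gradients derived by hand), and it also shows that once W and b are known, the trained line is simply a function of x.

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded so runs are repeatable (assumption)

# Same synthetic data as the TensorFlow script: y = 0.1 * x + 0.3
x_data = rng.random(100)
y_data = x_data * 0.1 + 0.3

# Same starting point as the graph: W ~ U(-1, 1), b = 0
W = rng.uniform(-1.0, 1.0)
b = 0.0
learning_rate = 0.5  # same rate passed to GradientDescentOptimizer

for step in range(201):
    err = (W * x_data + b) - y_data
    # Gradients of loss = mean(err ** 2) with respect to W and b
    grad_W = 2.0 * np.mean(err * x_data)
    grad_b = 2.0 * np.mean(err)
    W -= learning_rate * grad_W
    b -= learning_rate * grad_b
    if step % 20 == 0:
        print(step, W, b)


def predict(x):
    """Evaluate the fitted line at x (a scalar or an array)."""
    return W * np.asarray(x, dtype=float) + b


print(predict([0.0, 0.5, 10.0]))
```

After 201 steps W and b should be very close to 0.1 and 0.3, mirroring the values the TensorFlow loop prints.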

OK, that part works. Now I just want to load the model back and query it to get a predicted value. Here is my attempt:

Loading and Querying the Trained Line Model

# This is going to load the line model
import tensorflow as tf

sess = tf.Session()
new_saver = tf.train.import_meta_graph('linemodel.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./')) # latest checkpoint
all_vars = tf.global_variables()
for v in all_vars:
    v_ = sess.run(v)
    print("This is {} with value: {}".format(v.name, v_))
    # this works


# None of the below works
# Tried this as well
#fetches = {
#   "input": tf.constant(10, name='input')
#}

#feed_dict = {"input": tf.constant(10, name='input')}
#vals = sess.run(fetches, feed_dict = feed_dict)
# Tried this and it didn't work
# query_value = tf.constant(10, name='query')

# print(sess.run(query_value))

This is a really basic question, but how can I just pass in a value and use my line almost like a function? Do I need to change the way the line model is constructed? My guess is that the computation graph isn't set up so that the output is something I can actually fetch. Is that correct? If so, how should I modify this program?


Solution

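  • In your training script, y is computed from the NumPy constant x_data, so the saved graph has no input you can feed. You have to add a prediction op to the restored graph yourself: a placeholder for x and an op computing W * x + b from the restored variables. I added a couple of lines to your code and it gives the desired output. Please check it.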

    import tensorflow as tf
    import numpy as np
    
    sess = tf.Session()
    new_saver = tf.train.import_meta_graph('linemodel.meta')
    new_saver.restore(sess, tf.train.latest_checkpoint('./'))  # latest checkpoint
    all_vars = tf.global_variables()
    
    # The restored variables, in creation order (W first, then b)
    W = all_vars[0]
    b = all_vars[1]
    
    # Build the prediction op: x is fed through a placeholder
    x = tf.placeholder(tf.float32)
    y = tf.add(tf.multiply(W, x), b)
    
    # Don't run tf.global_variables_initializer() here: after restore() it
    # would overwrite the trained W and b with fresh random/zero values.
    print(sess.run(all_vars))
    for i in range(2):
        x_ip = np.random.rand(10).astype(np.float32)  # batch_size: 10
        vals = sess.run(y, feed_dict={x: x_ip})
        print(vals)
    

    Output (the restored W and b, followed by two batches of 10 predictions):

    [array([ 0.1000001], dtype=float32), array([ 0.29999995], dtype=float32)]

    Since W ≈ 0.1 and b ≈ 0.3 and the inputs come from np.random.rand (values in [0, 1)), every predicted value lies between about 0.3 and 0.4.
    

    I hope this helps.
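On the second part of the question (whether the model should be constructed differently): one common alternative is to give the graph a feedable input and a named output at training time, so the loading script only has to look them up. A minimal sketch, assuming TensorFlow 1.15 or 2.x through the tf.compat.v1 graph API; the tensor names "x" and "y_pred" and the temporary checkpoint directory are hypothetical choices for illustration, not part of the original code.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1             # TF1-style graph/session API (TF 1.15 or 2.x)
tf1.disable_eager_execution()

ckpt_dir = tempfile.mkdtemp()  # hypothetical checkpoint location
ckpt_prefix = os.path.join(ckpt_dir, "linemodel")

# ---- Training script: give the input and the output tensors names ----
with tf1.Graph().as_default():
    x_data = np.random.rand(100).astype(np.float32)
    y_data = x_data * 0.1 + 0.3

    x = tf1.placeholder(tf.float32, shape=[None], name="x")  # feedable input
    W = tf1.Variable(tf1.random_uniform([1], -1.0, 1.0), name="W")
    b = tf1.Variable(tf1.zeros([1]), name="b")
    y_pred = tf1.add(W * x, b, name="y_pred")                 # named output

    loss = tf1.reduce_mean(tf1.square(y_pred - y_data))
    train = tf1.train.GradientDescentOptimizer(0.5).minimize(loss)
    saver = tf1.train.Saver()

    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        for step in range(201):
            sess.run(train, feed_dict={x: x_data})
        saver.save(sess, ckpt_prefix)

# ---- Query script: rebuild the graph from the .meta file, restore, feed x ----
with tf1.Graph().as_default() as graph:
    with tf1.Session() as sess:
        loader = tf1.train.import_meta_graph(ckpt_prefix + ".meta")
        loader.restore(sess, tf1.train.latest_checkpoint(ckpt_dir))

        # Look the tensors up by the names given at training time ("name:0")
        x_in = graph.get_tensor_by_name("x:0")
        y_out = graph.get_tensor_by_name("y_pred:0")

        query = np.array([0.0, 0.5, 10.0], dtype=np.float32)
        predictions = sess.run(y_out, feed_dict={x_in: query})
        print(predictions)  # approximately 0.1 * query + 0.3
```

Because the placeholder is declared with shape=[None], the restored model accepts any number of x values per query, which is what lets it be used like a function. Note that no initializer is run in the query script; restore() supplies the trained values.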