
Running an imported TensorFlow graph fails with uninitialized variables


I'm trying to run TensorFlow training in Java using the javacpp-presets for TensorFlow. I generated a .pb file with tf.train.write_graph(sess.graph_def, '.', 'example.pb', as_text=False), as shown below.

import tensorflow as tf
import numpy as np

x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3
Weights = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='Weights')
biases = tf.Variable(tf.zeros([1]), name='biases')
y = Weights * x_data + biases
loss = tf.reduce_mean(tf.square(y - y_data))  # compute the loss
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss, name='train')
init = tf.global_variables_initializer()

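with tf.Session() as sess:
   sess.run(init)  # initialize the variables before reading them
   print(sess.run(Weights), sess.run(biases))
   # Exports the graph structure only; variable values are not saved.
   tf.train.write_graph(sess.graph_def, '.', 'example.pb', as_text=False)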

I got:

Exception in thread "main" java.lang.Exception: Attempting to use uninitialized value Weights"

when I run:

tensorflow.Status s = session.Run(new StringTensorPairVector(new String[] {}, new Tensor[] {}), new tensorflow.StringVector(), new tensorflow.StringVector("train"), outputs);  

after loading the graph with tensorflow.ReadBinaryProto(Env.Default(), "./example.pb", def);

Is there a javacpp-presets API that does the same work as init = tf.global_variables_initializer()?
Or is there a C++ TensorFlow API I can use to initialize all variables?


Solution

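  • In your Python program, init (the result of tf.global_variables_initializer()) is a tf.Operation that, when passed to sess.run(), initializes the variables. The exported graph contains only the graph structure, not the variable values, so the initializer has to be run again in the new session. If you capture the value of init.name when building the Python graph, you can pass that name to session.Run() in your Java program before running the training step.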

    I'm not 100% sure what the API for javacpp-presets looks like, but I think you would be able to do this as:

    tensorflow.Status s = session.Run(
        new StringTensorPairVector(new String[] {}, new Tensor[] {}),
        new tensorflow.StringVector(),
        new tensorflow.StringVector(value_of_init_dot_name),
        outputs);  
    

    ...where value_of_init_dot_name is the value of init.name you obtained from the Python program.
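
To make the reasoning above concrete, here is a toy model in plain Python. It is not TensorFlow and not the javacpp-presets API; ToyGraph, ToySession, UninitializedError and build_graph are hypothetical names invented for this sketch. It only illustrates the idea that an exported graph carries named ops but no variable values, so a fresh session has to run the initializer by name before the training op. In the real Python script, print(init.name) would show the name to pass from Java; in TensorFlow 1.x this is usually "init", but printing it is the reliable way to check.

```python
class UninitializedError(Exception):
    """Raised when a variable is read before its initializer has run."""


class ToyGraph:
    """Named ops only. Like an exported graph file, it stores no variable values."""

    def __init__(self):
        self.ops = {}  # op name -> function taking a session

    def add_op(self, name, fn):
        self.ops[name] = fn
        return name  # keep the name, much like recording init.name


class ToySession:
    """Owns the variable values, which start out empty in every new session."""

    def __init__(self, graph):
        self.graph = graph
        self.values = {}

    def read(self, var_name):
        if var_name not in self.values:
            raise UninitializedError(
                "Attempting to use uninitialized value " + var_name)
        return self.values[var_name]

    def run(self, op_name):
        return self.graph.ops[op_name](self)


def build_graph():
    graph = ToyGraph()

    def init(sess):
        sess.values["Weights"] = 0.5
        sess.values["biases"] = 0.0

    def train(sess):
        # Stand-in update, not real gradient descent: it only needs to
        # read the variables, which fails unless "init" ran first.
        sess.values["Weights"] = sess.read("Weights") - 0.1
        sess.values["biases"] = sess.read("biases") + 0.1

    init_name = graph.add_op("init", init)
    train_name = graph.add_op("train", train)
    return graph, init_name, train_name


graph, init_name, train_name = build_graph()
session = ToySession(graph)  # plays the role of the fresh session on the Java side

try:
    session.run(train_name)  # same mistake as in the question
except UninitializedError as err:
    print(err)

session.run(init_name)   # run the initializer by its recorded name
session.run(train_name)  # now succeeds
print(session.values)
```

The point of the design is that the session, not the graph, owns the variable state. That is why running init in the Python session that exported the file has no effect on a session created later from that file.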