I am working on a VAE project in TensorFlow where the encoder/decoder networks are built in functions. The idea is to be able to save, then load the trained model and do sampling using the decoder function.
After restoring the model, I am having trouble getting the decoder function to run with the restored, trained variables: I get an "Uninitialized value" error. I assume it is because the function either creates a new set of variables or overwrites the existing ones, but I cannot figure out how to solve this. Here is some code:
import os

import tensorflow as tf
from tensorflow.contrib import distributions  # assumed source of `distributions` (TF 1.x)

# build_prior, latent_dim, input_size and save_dir are defined elsewhere in the class (omitted here)
class VAE(object):
    def __init__(self, restore=True):
        self.session = tf.Session()
        if restore:
            self.restore_model()
        self.build_decoder = tf.make_template('decoder', self._build_decoder)

    @staticmethod
    def _build_decoder(z, output_size=768, hidden_size=200,
                       hidden_activation=tf.nn.elu, output_activation=tf.nn.sigmoid):
        x = tf.layers.dense(z, hidden_size, activation=hidden_activation)
        x = tf.layers.dense(x, hidden_size, activation=hidden_activation)
        logits = tf.layers.dense(x, output_size, activation=output_activation)
        return distributions.Independent(distributions.Bernoulli(logits), 2)

    def sample_decoder(self, n_samples):
        prior = self.build_prior(self.latent_dim)
        samples = self.build_decoder(prior.sample(n_samples), self.input_size).mean()
        return self.session.run([samples])

    def restore_model(self):
        print("Restoring")
        self.saver = tf.train.import_meta_graph(os.path.join(self.save_dir, "turbolearn.meta"))
        self.saver.restore(self.session, tf.train.latest_checkpoint(self.save_dir))
        self._restored = True
What I want to run is:
samples = vae.sample_decoder(5)
In my training routine, I run:
if self.checkpoint:
    self.saver.save(self.session, os.path.join(self.save_dir, "myvae"), write_meta_graph=True)
Based on the suggested answer below, I changed the restore method to:
self.saver = tf.train.Saver()
self.saver.restore(self.session, tf.train.latest_checkpoint(self.save_dir))
But now I get a ValueError when the Saver() object is created:
ValueError: No variables to save
tf.train.import_meta_graph restores the graph, i.e. it rebuilds the network architecture that was stored in the file. tf.train.Saver.restore, on the other hand, only restores the variable values from the file into the current graph in the session (this naturally fails if some values in the file belong to variables that do not exist in the currently active graph).
So if you already build the network layers in code, you don't need to call tf.train.import_meta_graph. Doing both leaves you with the imported variables, which the restore fills in, plus a second set created by your own code, which nothing initializes. That would explain the "Uninitialized value" error.
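To make the distinction concrete, here is a minimal sketch of both routes, written against the tf.compat.v1 API. The toy one-layer decoder, the tensor names "z" and "x_mean", and the helper functions are hypothetical stand-ins for the real VAE; only the checkpoint prefix "myvae" comes from the question. Route (a) rebuilds the layers in code and uses Saver.restore for the values only; route (b) calls import_meta_graph in an empty graph and builds nothing itself.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def _build_decoder(z):
    # Hypothetical toy decoder: one dense layer made with get_variable.
    w = tf.get_variable("w", [2, 3])
    b = tf.get_variable("b", [3], initializer=tf.zeros_initializer())
    return tf.nn.sigmoid(tf.matmul(z, w) + b)


def train_and_save(save_dir, z_value):
    with tf.Graph().as_default():
        decoder = tf.make_template("decoder", _build_decoder)
        z = tf.placeholder(tf.float32, [None, 2], name="z")
        x_mean = tf.identity(decoder(z), name="x_mean")  # named so route (b) can find it
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())  # stands in for training
            saver.save(sess, os.path.join(save_dir, "myvae"))  # also writes myvae.meta
            return sess.run(x_mean, {z: z_value})


def restore_by_rebuilding(save_dir, z_value):
    # (a) Rebuild the architecture in code; Saver.restore only loads values.
    with tf.Graph().as_default():
        decoder = tf.make_template("decoder", _build_decoder)
        z = tf.placeholder(tf.float32, [None, 2])
        x_mean = decoder(z)  # the first call creates the variables
        saver = tf.train.Saver()  # created after the variables exist
        with tf.Session() as sess:
            saver.restore(sess, tf.train.latest_checkpoint(save_dir))
            return sess.run(x_mean, {z: z_value})


def restore_by_importing(save_dir, z_value):
    # (b) Import the stored architecture; build nothing, look tensors up by name.
    with tf.Graph().as_default():
        saver = tf.train.import_meta_graph(os.path.join(save_dir, "myvae.meta"))
        with tf.Session() as sess:
            saver.restore(sess, tf.train.latest_checkpoint(save_dir))
            return sess.run("x_mean:0", {"z:0": z_value})
```

Both restore functions should reproduce the output computed right after saving. In route (b) you do not call your Python decoder function again; you fetch the imported tensors by name instead.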
I'm not sure what the rest of your code looks like, but here are some suggestions. First build the graph, then create the session, and finally restore if applicable. Your __init__ might then look like this:
def __init__(self, restore=True):
    self.build_decoder = tf.make_template('decoder', self._build_decoder)
    self.session = tf.Session()
    if restore:
        self.restore_model()
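One caveat with that ordering: tf.make_template creates its variables only the first time the template is called, so a tf.train.Saver() constructed before build_decoder has ever been called still sees no variables, which matches the "No variables to save" error. A minimal sketch of the build, session, restore order that accounts for this is below, assuming the tf.compat.v1 API. The toy decoder returns sigmoid means instead of a Bernoulli distribution, the standard-normal tf.random_normal prior stands in for build_prior, and LATENT_DIM, OUTPUT_SIZE and save() are hypothetical.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

LATENT_DIM, OUTPUT_SIZE = 2, 4  # hypothetical toy sizes


def _build_decoder(z):
    # Hypothetical toy decoder: one dense layer returning Bernoulli means.
    w = tf.get_variable("w", [LATENT_DIM, OUTPUT_SIZE])
    b = tf.get_variable("b", [OUTPUT_SIZE], initializer=tf.zeros_initializer())
    return tf.nn.sigmoid(tf.matmul(z, w) + b)


class VAE(object):
    def __init__(self, save_dir, restore=True):
        self.save_dir = save_dir
        self.graph = tf.Graph()
        with self.graph.as_default():
            # 1. Build the graph. Calling the template here creates its variables.
            self.build_decoder = tf.make_template("decoder", _build_decoder)
            self.n_samples = tf.placeholder(tf.int32, [], name="n_samples")
            z = tf.random_normal([self.n_samples, LATENT_DIM])  # standard-normal prior
            self.samples = self.build_decoder(z)
            self.saver = tf.train.Saver()  # the variables exist by now
        # 2. Create the session.
        self.session = tf.Session(graph=self.graph)
        # 3. Restore if applicable, otherwise start from fresh values.
        if restore:
            self.saver.restore(self.session, tf.train.latest_checkpoint(save_dir))
        else:
            with self.graph.as_default():
                self.session.run(tf.global_variables_initializer())

    def save(self):
        return self.saver.save(self.session, os.path.join(self.save_dir, "myvae"))

    def sample_decoder(self, n_samples):
        return self.session.run(self.samples, {self.n_samples: n_samples})
```

Building the sampling ops once in __init__ (rather than inside sample_decoder) is a design choice in this sketch: each call then only runs existing ops instead of adding new nodes to the graph.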
However, if you are only restoring the encoder and building the decoder anew, you might build the decoder last. In that case, don't forget to initialize the decoder's variables before using them.
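A minimal sketch of that variant, assuming the tf.compat.v1 API and a hypothetical one-layer _dense helper in place of the real encoder and decoder: a Saver() created before the decoder exists only covers the encoder variables, and the decoder variables, selected here by their "decoder" scope, are initialized separately with tf.variables_initializer.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def _dense(x, n_out):
    # Hypothetical one-layer stand-in for the real encoder/decoder networks.
    w = tf.get_variable("w", [int(x.shape[-1]), n_out])
    b = tf.get_variable("b", [n_out], initializer=tf.zeros_initializer())
    return tf.matmul(x, w) + b


def save_encoder(save_dir):
    with tf.Graph().as_default():
        encoder = tf.make_template("encoder", _dense, n_out=2)
        encoder(tf.placeholder(tf.float32, [None, 4]))
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())  # stands in for training
            saver.save(sess, os.path.join(save_dir, "encoder_only"))
            return sess.run(tf.trainable_variables())


def restore_encoder_then_build_decoder(save_dir):
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, [None, 4])
        z = tf.make_template("encoder", _dense, n_out=2)(x)
        # Created before the decoder exists, so it only covers encoder variables.
        saver = tf.train.Saver()
        # Build the decoder last.
        x_mean = tf.nn.sigmoid(tf.make_template("decoder", _dense, n_out=4)(z))
        decoder_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="decoder")
        with tf.Session() as sess:
            saver.restore(sess, tf.train.latest_checkpoint(save_dir))
            sess.run(tf.variables_initializer(decoder_vars))  # decoder only
            uninitialized = sess.run(tf.report_uninitialized_variables())
            encoder_values = sess.run(
                tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="encoder"))
            return encoder_values, uninitialized, sess.run(x_mean, {x: np.ones((3, 4))})
```

Running tf.global_variables_initializer() after the restore would instead overwrite the restored encoder values, which is why the initializer here is limited to the decoder's variables.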