Tags: python, tensorflow, machine-learning, tensorflow-probability

Saving and restoring functions in TensorFlow


I am working on a VAE project in TensorFlow where the encoder/decoder networks are built in functions. The idea is to be able to save the trained model, then load it and do sampling, using the encoder function.

After restoring the model, I am having trouble getting the decoder function to run and give me back the restored, trained variables; I get an "Uninitialized value" error. I assume the function is either creating a new set of variables, overwriting the existing ones, or something similar, but I cannot figure out how to solve this. Here is some code:

import os
import tensorflow as tf
# Assuming TF 1.x; `distributions` is assumed to come from TensorFlow Probability
import tensorflow_probability as tfp
distributions = tfp.distributions


class VAE(object):
    def __init__(self, restore=True):
        self.session = tf.Session()
        if restore:
            self.restore_model()
            self.build_decoder = tf.make_template('decoder', self._build_decoder)

    @staticmethod
    def _build_decoder(z, output_size=768, hidden_size=200,
                       hidden_activation=tf.nn.elu, output_activation=tf.nn.sigmoid):
        x = tf.layers.dense(z, hidden_size, activation=hidden_activation)
        x = tf.layers.dense(x, hidden_size, activation=hidden_activation)
        logits = tf.layers.dense(x, output_size, activation=output_activation)
        return distributions.Independent(distributions.Bernoulli(logits), 2)

    def sample_decoder(self, n_samples):
        prior = self.build_prior(self.latent_dim)
        samples = self.build_decoder(prior.sample(n_samples), self.input_size).mean()
        return self.session.run([samples])

    def restore_model(self):
        print("Restoring")
        self.saver = tf.train.import_meta_graph(os.path.join(self.save_dir, "turbolearn.meta"))
        self.saver.restore(self.session, tf.train.latest_checkpoint(self.save_dir))
        self._restored = True

I want to run samples = vae.sample_decoder(5).

In my training routine, I run:

        if self.checkpoint:
            self.saver.save(self.session, os.path.join(self.save_dir, "myvae"), write_meta_graph=True)

UPDATE

Based on the suggested answer below, I changed the restore method:

self.saver = tf.train.Saver()
self.saver.restore(self.session, tf.train.latest_checkpoint(self.save_dir))

But now I get a ValueError when the Saver() object is created:

ValueError: No variables to save

Solution

  • tf.train.import_meta_graph restores the graph, i.e. it rebuilds the network architecture that was stored in the file. A call to tf.train.Saver.restore, on the other hand, only restores the variable values from the file into the current graph of the session (this naturally fails if some values in the file belong to variables that do not exist in the currently active graph).

    So if you already build the network layers in your code, you don't need to call tf.train.import_meta_graph. Otherwise this might be causing you problems.
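
    Roughly, the two restore patterns look like this (just a sketch, not your exact code; `build_model`, `save_dir` and the ".meta" filename are placeholders/assumptions):

    # Pattern 1: rebuild the graph in code, then restore only the variable values.
    tf.reset_default_graph()
    build_model()                     # placeholder: recreates the same layers/variables as at training time
    saver = tf.train.Saver()          # must be created AFTER the variables exist
    with tf.Session() as sess:
        saver.restore(sess, tf.train.latest_checkpoint(save_dir))

    # Pattern 2: let the .meta file rebuild the graph, then restore the values into it.
    tf.reset_default_graph()
    saver = tf.train.import_meta_graph(os.path.join(save_dir, "myvae.meta"))
    with tf.Session() as sess:
        saver.restore(sess, tf.train.latest_checkpoint(save_dir))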

    I'm not sure what the rest of your code looks like, but here are some suggestions. First build the graph, then create the session, and finally restore if applicable. Your __init__ might then look like this:

    def __init__(self, restore=True):
        self.build_decoder = tf.make_template('decoder', self._build_decoder)
        self.session = tf.Session()
        if restore:
            self.restore_model()
    

    However, if you are only restoring the encoder and building the decoder anew, you might build the decoder last. But then don't forget to initialize its variables before use.
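
    For that last case, one way to initialize just the newly built decoder variables (leaving the restored ones untouched) is to collect them by scope; a minimal sketch, assuming the decoder template puts its variables under a 'decoder' variable scope, with z and output_size standing in for whatever you feed the decoder:

    # Build the decoder; tf.make_template creates its variables on the first call.
    build_decoder = tf.make_template('decoder', self._build_decoder)
    _ = build_decoder(z, output_size)          # first call creates the variables

    # Initialize only the decoder's variables; everything else keeps its restored values.
    decoder_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='decoder')
    self.session.run(tf.variables_initializer(decoder_vars))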