python · serialization · tensorflow · keras

Serializing and deserializing a TensorFlow model in memory and continuing training


I have seen variations of this question asked, but I haven't quite found a satisfactory answer yet. Basically, I would like the TensorFlow equivalent of Keras's model.to_json(), model.get_weights(), model.from_json(), and model.set_weights(). I think I am getting close, but I am stuck. I'd prefer to get the weights and graph in the same string, but I understand if that isn't possible.

Currently, what I have is:

import tensorflow as tf
from google.protobuf import json_format

g = optimizer.minimize(loss_op,
                       global_step=tf.train.get_global_step())

# serialize the GraphDef to a JSON string...
de = g.graph.as_graph_def()
json_string = json_format.MessageToJson(de)

# ...and parse it back into a fresh GraphDef
gd = tf.GraphDef()
gd = json_format.Parse(json_string, gd)

That seems to recreate the graph fine, but obviously the variables, weights, etc. are not included, since they belong to the meta graph rather than the GraphDef. For the meta graph, the only thing I see is export_meta_graph, which doesn't seem to serialize in the same manner. I saw that MetaGraphDef is a protobuf as well, but I don't know how to serialize the variables that way.

So in short: how would you take a TensorFlow model (model as in weights, graph, etc.), serialize it to a string (preferably JSON), then deserialize it and continue training or serve predictions?

Here are things I have tried that get me close, but most have the limitation of needing to write to disk, which I can't do in this case:

Gist on GitHub

This is the closest one I found, but the link it gives for serializing a MetaGraph is dead.


Solution

  • If you want the equivalent of keras Model.get_weights() and Model.set_weights(), these methods aren't strongly tied to keras internals and can be easily extracted.

    Original code

    Here's what they look like in the Keras source code:

    def get_weights(self):
      weights = []
      for layer in self.layers:
        weights += layer.weights
      return K.batch_get_value(weights)   # this is just `get_session().run(weights)`
    
    def set_weights(self, weights):
      tuples = []
      for layer in self.layers:
        num_param = len(layer.weights)
        layer_weights = weights[:num_param]
        for sw, w in zip(layer.weights, layer_weights):
          tuples.append((sw, w))
        weights = weights[num_param:]
      K.batch_set_value(tuples)  # another wrapper over `get_session().run(...)`
    

    Keras's weights is a list of NumPy arrays (not JSON). As you can see, it uses the fact that the model architecture is known (self.layers), which allows it to reconstruct the correct mapping from variables to values. Some seemingly non-trivial work is done in K.batch_set_value, but in fact it simply prepares assign ops and runs them in the session.

    Getting and setting weights in pure tensorflow

    import numpy as np
    import tensorflow as tf

    def tensorflow_get_weights():
      # fetch all trainable variables in a single session run
      vars = tf.trainable_variables()
      values = tf.get_default_session().run(vars)
      return list(zip([var.name for var in vars], values))  # a list, so it can be reused

    def tensorflow_set_weights(weights):
      # build one assign op per variable, then run them all in one call
      assign_ops = []
      feed_dict = {}
      for var_name, value in weights:
        var = tf.get_default_session().graph.get_tensor_by_name(var_name)
        value = np.asarray(value)
        assign_placeholder = tf.placeholder(var.dtype, shape=value.shape)
        assign_op = tf.assign(var, assign_placeholder)
        assign_ops.append(assign_op)
        feed_dict[assign_placeholder] = value
      tf.get_default_session().run(assign_ops, feed_dict=feed_dict)
    

    Here I assume that you want to serialize / deserialize the whole model (i.e., all trainable variables) in the default session. If that is not the case, the functions above are easy to customize.
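If you also want the weights as a JSON string (as the question asks), the name/value pairs from tensorflow_get_weights can be round-tripped through plain json. Here is one possible encoding (a sketch; the field names are my own choice, not any standard format):

```python
import json

import numpy as np

def weights_to_json(weights):
  # weights: iterable of (variable_name, numpy_array) pairs,
  # e.g. the result of tensorflow_get_weights()
  return json.dumps([{'name': name,
                      'dtype': str(np.asarray(value).dtype),
                      'shape': list(np.asarray(value).shape),
                      'value': np.asarray(value).ravel().tolist()}
                     for name, value in weights])

def weights_from_json(json_string):
  # rebuild the (name, array) pairs expected by tensorflow_set_weights()
  return [(entry['name'],
           np.asarray(entry['value'], dtype=entry['dtype']).reshape(entry['shape']))
          for entry in json.loads(json_string)]
```

The dtype and shape are stored explicitly so the arrays come back exactly as they went in, rather than being inferred from the nested lists.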

    Testing

    x = tf.placeholder(shape=[None, 5], dtype=tf.float32, name='x')
    W = tf.Variable(np.zeros([5, 5]), dtype=tf.float32, name='W')
    b = tf.Variable(np.zeros([5]), dtype=tf.float32, name='b')
    y = tf.add(tf.matmul(x, W), b)
    
    with tf.Session() as session:
      session.run(tf.global_variables_initializer())
    
      # Save the weights
      w = tensorflow_get_weights()
      print(W.eval(), b.eval())
    
      # Update the model
      session.run([tf.assign(W, np.ones([5, 5])), tf.assign(b, np.ones([5]) * 2)])
      print(W.eval(), b.eval())
    
      # Restore the weights
      tensorflow_set_weights(w)
      print(W.eval(), b.eval())
    

    If you run this test, you should see that the model was initialized with zeros, then updated to ones, and finally restored back to zeros.
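Finally, to get the graph and the weights into a single string, as the question prefers, you can bundle the GraphDef JSON (the json_string from the question) and a JSON encoding of the weights under one top-level object. A minimal sketch, assuming both pieces are already JSON strings:

```python
import json

def model_to_json(graph_json, weights_json):
  # pack the two JSON documents into a single string
  return json.dumps({'graph': json.loads(graph_json),
                     'weights': json.loads(weights_json)})

def model_from_json(model_json):
  # split back into the graph JSON (for json_format.Parse)
  # and the weights JSON
  bundle = json.loads(model_json)
  return json.dumps(bundle['graph']), json.dumps(bundle['weights'])
```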