I have seen variations of this question asked, but I haven't quite found a satisfactory answer yet. Basically, I would like to do the TensorFlow equivalent of Keras's `model.to_json()`, `model.get_weights()`, `model.from_json()` and `model.set_weights()`. I think I am getting close, but I am stuck. I'd prefer to get the weights and the graph in the same string, but I understand if that isn't possible.
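Currently, what I have is:

```python
import tensorflow as tf
from google.protobuf import json_format

# `optimizer` and `loss_op` are defined elsewhere in my model code.
g = optimizer.minimize(loss_op, global_step=tf.train.get_global_step())
de = g.graph.as_graph_def()
json_string = json_format.MessageToJson(de)

# Round trip: parse the JSON back into a GraphDef.
gd = tf.GraphDef()
gd = json_format.Parse(json_string, gd)
```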
That seems to recreate the graph fine, but a `GraphDef` only describes the ops: it does not contain the variables' values (the weights). There is also the meta graph, but the only function I see for it is `export_meta_graph`, which doesn't seem to serialize in the same way. I saw that the meta graph has a protobuf representation, but I don't know how to serialize the variables' values along with it.
So in short: how would you take a TensorFlow model (weights, graph, etc.), serialize it to a string (preferably JSON), then deserialize it and either continue training or serve predictions?
Here are the things I have tried that get me close, but most of them need to write to disk, which I can't do in this case:
This is the closest one I found, but the link about serializing a metagraph is broken.
If you want the equivalent of Keras's `Model.get_weights()` and `Model.set_weights()`, note that these methods aren't strongly tied to Keras internals and can easily be extracted.
Here's what they look like in the Keras source code:
```python
def get_weights(self):
    weights = []
    for layer in self.layers:
        weights += layer.weights
    return K.batch_get_value(weights)  # this is just `get_session().run(weights)`

def set_weights(self, weights):
    tuples = []
    for layer in self.layers:
        num_param = len(layer.weights)
        layer_weights = weights[:num_param]
        for sw, w in zip(layer.weights, layer_weights):
            tuples.append((sw, w))
        weights = weights[num_param:]
    K.batch_set_value(tuples)  # another wrapper over `get_session().run(...)`
```
Keras's `weights` is a list of NumPy arrays (not JSON). As you can see, it relies on the model architecture being known (`self.layers`), which lets it reconstruct the correct mapping from variables to values. `K.batch_set_value` looks like it does some non-trivial work, but in fact it simply prepares assign ops and runs them in the session.
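Because that list is just NumPy arrays, turning it into the JSON string the question asks for only requires storing each array's dtype and shape next to its flattened values. Here is a minimal sketch; `weights_to_json` and `weights_from_json` are hypothetical helper names, not part of Keras or TensorFlow.

```python
import json

import numpy as np


def weights_to_json(weights):
    """Encode a list of NumPy arrays (e.g. from Model.get_weights()) as a JSON string."""
    return json.dumps([
        {"dtype": str(w.dtype), "shape": list(w.shape), "values": w.ravel().tolist()}
        for w in weights
    ])


def weights_from_json(json_string):
    """Decode a string produced by weights_to_json back into a list of NumPy arrays."""
    return [
        np.asarray(entry["values"], dtype=entry["dtype"]).reshape(entry["shape"])
        for entry in json.loads(json_string)
    ]
```

With a Keras model this would be used as `model.set_weights(weights_from_json(weights_to_json(model.get_weights())))`. Storing the shape explicitly and flattening with `ravel()` keeps arrays with a zero-length dimension intact; `tolist()` alone would turn an array of shape `(0, 3)` into `[]` and lose the second dimension.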
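Plain TensorFlow has no `layers` list to walk, so the version below identifies each variable by its name instead:

```python
import numpy as np
import tensorflow as tf


def tensorflow_get_weights():
    """Return (variable name, value) pairs for all trainable variables."""
    variables = tf.trainable_variables()
    values = tf.get_default_session().run(variables)
    # A list (not a one-shot zip iterator) so the result can be reused.
    return list(zip([var.name for var in variables], values))


def tensorflow_set_weights(weights):
    """Assign values from (variable name, value) pairs back to the variables."""
    session = tf.get_default_session()
    variables_by_name = {var.name: var for var in tf.global_variables()}
    assign_ops = []
    feed_dict = {}
    for var_name, value in weights:
        var = variables_by_name[var_name]
        value = np.asarray(value)
        # Feed through a placeholder so the value isn't baked into the graph.
        # Note: each call still adds new placeholder/assign ops to the graph.
        assign_placeholder = tf.placeholder(var.dtype.base_dtype, shape=value.shape)
        assign_ops.append(tf.assign(var, assign_placeholder))
        feed_dict[assign_placeholder] = value
    session.run(assign_ops, feed_dict=feed_dict)
```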
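Here I assume that you want to serialize / deserialize the whole model (i.e., all trainable variables) in the default session. If that's not the case, the functions above are easy to customize. For example, to continue training you probably also want the non-trainable variables, such as the global step and the optimizer's slot variables, which `tf.global_variables()` includes and `tf.trainable_variables()` does not.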
```python
import numpy as np
import tensorflow as tf

x = tf.placeholder(shape=[None, 5], dtype=tf.float32, name='x')
W = tf.Variable(np.zeros([5, 5]), dtype=tf.float32, name='W')
b = tf.Variable(np.zeros([5]), dtype=tf.float32, name='b')
y = tf.add(tf.matmul(x, W), b)

with tf.Session() as session:
    session.run(tf.global_variables_initializer())

    # Save the weights
    w = tensorflow_get_weights()
    print(W.eval(), b.eval())

    # Update the model
    session.run([tf.assign(W, np.ones([5, 5])), tf.assign(b, np.ones([5]) * 2)])
    print(W.eval(), b.eval())

    # Restore the weights
    tensorflow_set_weights(w)
    print(W.eval(), b.eval())
```
If you run this test, you should see the weights start at zeros, get updated (to ones for `W` and twos for `b`), and then get restored back to zeros.
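The approach above keeps the values in memory as NumPy arrays and assumes the graph is rebuilt by the same Python code. Since the question asks for the graph and the weights in one string without writing to disk, here is a minimal sketch that bundles both into a single JSON document. `tf.train.export_meta_graph` called without a filename only returns the `MetaGraphDef` proto (nothing is written), which `json_format` can serialize just like the `GraphDef` in the question; the variable values go alongside it, keyed by name. The helpers `serialize_model` and `deserialize_model` are hypothetical names, and the sketch assumes a TensorFlow version that provides `tensorflow.compat.v1` (on older 1.x releases, `import tensorflow as tf` should behave the same).

```python
import json

import numpy as np
import tensorflow.compat.v1 as tf
from google.protobuf import json_format


def serialize_model(graph, session):
    """Return one JSON string holding the graph structure and every variable's value."""
    with graph.as_default():
        # No filename: the MetaGraphDef is only built in memory.
        meta_graph = tf.train.export_meta_graph(graph=graph)
        # Global (not only trainable) variables, so optimizer slots and the
        # global step are included and training can resume.
        variables = tf.global_variables()
    weights = {}
    for var, value in zip(variables, session.run(variables)):
        value = np.asarray(value)
        weights[var.name] = {"dtype": str(value.dtype), "shape": list(value.shape),
                             "values": value.ravel().tolist()}
    return json.dumps({"meta_graph": json_format.MessageToDict(meta_graph),
                       "weights": weights})


def deserialize_model(json_string):
    """Rebuild the graph in a fresh tf.Graph and return (graph, session) with values loaded."""
    bundle = json.loads(json_string)
    meta_graph = json_format.ParseDict(bundle["meta_graph"], tf.MetaGraphDef())
    graph = tf.Graph()
    with graph.as_default():
        tf.train.import_meta_graph(meta_graph)
        variables = {var.name: var for var in tf.global_variables()}
        assign_ops, feed_dict = [], {}
        for name, entry in bundle["weights"].items():
            var = variables[name]
            value = np.asarray(entry["values"], dtype=entry["dtype"]).reshape(entry["shape"])
            placeholder = tf.placeholder(var.dtype.base_dtype, shape=value.shape)
            assign_ops.append(tf.assign(var, placeholder))
            feed_dict[placeholder] = value
    session = tf.Session(graph=graph)
    session.run(assign_ops, feed_dict=feed_dict)
    return graph, session
```

After `deserialize_model`, tensors and ops can be looked up by name in the returned graph (for example `graph.get_tensor_by_name('y:0')`) and run with the returned session, either to serve predictions or to keep running the training op. Writing floats as JSON text is verbose; for large models, base64-encoding `value.tobytes()` instead of `tolist()` would keep the string much smaller.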