Tags: python, tensorflow, word2vec

How to do prediction using trained and stored tensorflow model


I have an existing trained model (specifically the TensorFlow word2vec example: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/udacity/5_word2vec.ipynb). I can restore it without trouble:

model1 = tf.train.import_meta_graph("models/model.meta")
model1.restore(sess, tf.train.latest_checkpoint("models/"))

But I don't know how to use the restored (already trained) model to make predictions. How do I run inference with a restored model?

Edit:

The model code is from the official TensorFlow repo: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/word2vec/word2vec_basic.py


Solution

  • Given how you are loading the checkpoint, the most direct way to use it for inference is the following.

    Load the placeholders:

    input = tf.get_default_graph().get_tensor_by_name("Placeholders/placeholder_name:0")
    ....
    

    Load the op you use to perform prediction:

    prediction = tf.get_default_graph().get_tensor_by_name("SomewhereInsideGraph/prediction_op_name:0")
    

    Create a session, then run the prediction op, feeding your data into the placeholders:

    sess = tf.Session()
    sess.run(prediction, feed_dict={input:input_data})
    
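    The three steps above can be sketched end to end. The snippet below is a self-contained toy, not the word2vec graph: the tensor names `input_x` and `prediction` are made up for illustration, and it uses `tf.compat.v1` so it also runs under TensorFlow 2.x. It saves a tiny graph, restores it into a fresh graph, looks the tensors up by name, and runs the prediction op:

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf  # TF1-style graph/session API on TF 2.x

tf.disable_eager_execution()

ckpt_dir = tempfile.mkdtemp()

# Training side: build a toy graph, give its tensors findable names, save it.
x = tf.placeholder(tf.float32, shape=[None, 3], name="input_x")
w = tf.Variable(tf.ones([3, 1]), name="w")
tf.identity(tf.matmul(x, w), name="prediction")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.train.Saver().save(sess, os.path.join(ckpt_dir, "model"))

# Inference side: start from an empty graph, restore, look tensors up by name.
tf.reset_default_graph()
with tf.Session() as sess:
    loader = tf.train.import_meta_graph(os.path.join(ckpt_dir, "model.meta"))
    loader.restore(sess, tf.train.latest_checkpoint(ckpt_dir))
    graph = tf.get_default_graph()
    x_t = graph.get_tensor_by_name("input_x:0")
    pred_t = graph.get_tensor_by_name("prediction:0")
    out = sess.run(pred_t, feed_dict={x_t: np.ones((2, 3), dtype=np.float32)})

print(out)  # each row is the sum of three ones -> 3.0
```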

    On the other hand, what I prefer to do is put the whole model creation inside the constructor of a class. Then I would do the following:

    tf.reset_default_graph()
    model = ModelClass()
    loader = tf.train.Saver()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    # restore expects a checkpoint prefix, not a bare directory
    loader.restore(sess, tf.train.latest_checkpoint(path_to_checkpoint_dir))
    
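    As a concrete illustration of that pattern, here is a minimal sketch. `ModelClass` and its tensors are toy stand-ins, not the word2vec graph, and `tf.compat.v1` keeps it runnable on TensorFlow 2.x; the restore call is commented out because it needs a real checkpoint:

```python
import numpy as np
import tensorflow.compat.v1 as tf  # TF1-style API; plain `tensorflow` on TF 1.x

tf.disable_eager_execution()

class ModelClass:
    """Toy stand-in: builds the whole graph inside the constructor."""
    def __init__(self):
        self.input = tf.placeholder(tf.float32, [None, 3], name="input")
        w = tf.get_variable("w", shape=[3, 1])
        self.prediction = tf.matmul(self.input, w)

tf.reset_default_graph()
model = ModelClass()
loader = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# With a real checkpoint directory you would now overwrite the freshly
# initialized variables with the trained values:
# loader.restore(sess, tf.train.latest_checkpoint("models/"))

out = sess.run(model.prediction,
               feed_dict={model.input: np.zeros((2, 3), np.float32)})
print(out.shape)  # (2, 1)
```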

    Since you want to load the embeddings from a trained word2vec model into another model, you can do something like:

    # Map the name in the checkpoint to the variable in the new graph
    embeddings_new_model = tf.Variable(...,name="embeddings")
    embedding_saver = tf.train.Saver({"embeddings_word2vec": embeddings_new_model})
    with tf.Session() as sess:
        embedding_saver.restore(sess, "word2vec_model_path")
    

    This assumes the embeddings variable in the trained word2vec checkpoint is named embeddings_word2vec; you can check the names stored in a checkpoint with tf.train.list_variables.
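    Once the embeddings are restored, "prediction" with word2vec usually means nearest-neighbor lookup in embedding space, as in the notebook linked in the question. Below is a NumPy sketch; the random matrix stands in for the restored `sess.run(embeddings_new_model)`, and the tiny vocabulary is made up for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-ins: in practice `embeddings` would come from the restored model,
# e.g. embeddings = sess.run(embeddings_new_model), and `vocab` from your data.
vocab = ["king", "queen", "man", "woman", "apple"]
embeddings = rng.standard_normal((len(vocab), 8))

# L2-normalize the rows so a dot product is a cosine similarity.
normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

def nearest(word, k=3):
    sims = normed @ normed[vocab.index(word)]
    order = np.argsort(-sims)                  # most similar first
    return [vocab[i] for i in order[1:k + 1]]  # order[0] is the word itself

print(nearest("king"))
```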