python, tensorflow, tensorflow-estimator, seq2seq

Reusing Embedding Variable For Inference in the Tf.Estimator API


In NMT with a seq2seq architecture, inference needs the embedding variable learned during training as an input to the GreedyEmbeddingHelper or the BeamSearchDecoder.

The question is, within the context of training and inferring using the Estimator API, how can we extract this trained embedding variable to be used for prediction?


Solution

  • I figured out a solution based on the following Stack Overflow answer. For the prediction phase, you can use tf.contrib.framework.load_variable to retrieve the embedding variable from a trained and saved TensorFlow model, as follows:

    if mode == tf.estimator.ModeKeys.PREDICT:
        embeddings = tf.constant(
            tf.contrib.framework.load_variable('.', 'embed/embeddings'))
        helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            embedding=embeddings,
            start_tokens=tf.fill([batch_size], 1),
            end_token=0)
    

    So in my case, I was running the code from the same folder containing the saved model, and my variable name was 'embed/embeddings'. Note that this only works with embeddings trained via a TensorFlow model. Otherwise, refer to the answer linked above.

    To find the variable name when using the Estimator API, you can call the estimator's get_variable_names() method, which returns a list of all the variable names saved in the graph.
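    If you want to inspect the checkpoint directly without an Estimator instance, the non-contrib utilities tf.train.list_variables and tf.train.load_variable do the same job. Below is a minimal, self-contained sketch: it first saves a toy checkpoint containing an 'embed/embeddings' variable (the name and shape are illustrative stand-ins for your trained model), then lists and loads the variable. The tf.compat.v1 calls are used so the graph-mode part also runs under TF 2.x.

    ```python
    import tensorflow as tf

    # Stand-in for a trained model: build a tiny graph with an
    # 'embed/embeddings' variable and save a checkpoint to the
    # current directory (name and shape are illustrative).
    graph = tf.Graph()
    with graph.as_default():
        with tf.compat.v1.variable_scope('embed'):
            tf.compat.v1.get_variable('embeddings', shape=[10, 4])
        saver = tf.compat.v1.train.Saver()
        with tf.compat.v1.Session(graph=graph) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            saver.save(sess, './model.ckpt')

    # List every variable stored in the checkpoint directory ...
    for name, shape in tf.train.list_variables('.'):
        print(name, shape)

    # ... and load the embedding matrix as a NumPy array,
    # without rebuilding any graph.
    embeddings = tf.train.load_variable('.', 'embed/embeddings')
    print(embeddings.shape)  # (10, 4)
    ```

    The array returned by tf.train.load_variable is what you would wrap in tf.constant and hand to the GreedyEmbeddingHelper, exactly as in the snippet above.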