
TensorFlow: export estimators for prediction


How can I export the estimator from the MNIST tutorial on TensorFlow's site and then import it again for prediction? Thank you!


Solution

  • The Estimator has a model_dir argument that sets where the model's checkpoints are saved during training. At prediction time, construct an Estimator with the same model_fn and model_dir and call its predict method: it rebuilds the graph and restores the latest checkpoint from model_dir.

    For the MNIST example, the prediction code would be:

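    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x API, as in the tutorial

    # cnn_model_fn and eval_data are defined in the MNIST tutorial.
    tf.reset_default_graph()
    
    # An input function that feeds the new data once, in order.
    predict_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={"x": eval_data},
        num_epochs=1,
        shuffle=False)
    
    # Use the same model_dir as during training so that predict()
    # restores the latest checkpoint saved there.
    mnist_classifier = tf.estimator.Estimator(
        model_fn=cnn_model_fn, model_dir="/tmp/mnist_convnet_model")
    
    # Prediction call: returns a generator yielding one dict per example.
    predictions = mnist_classifier.predict(input_fn=predict_input_fn)
    
    pred_class = np.array([p['classes'] for p in predictions]).squeeze()
    print(pred_class)
    
    # Output
    # [7 2 1 ... 4 5 6]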
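Because predict returns a generator of per-example dicts (keyed by whatever the model_fn puts in its predictions dict; in the tutorial's cnn_model_fn that is "classes" and "probabilities"), it can only be consumed once. The sketch below illustrates this without TensorFlow: `fake_predict` is a hypothetical stand-in that mimics the shape of the Estimator's output, and the logits are made-up sample values. Collecting the generator into a list first lets you read several keys from the same predictions.

```python
import numpy as np


def fake_predict(logits):
    """Hypothetical stand-in for Estimator.predict: yields one dict per example,
    mimicking the 'classes'/'probabilities' keys of the tutorial's cnn_model_fn."""
    for row in logits:
        exp = np.exp(row - row.max())  # numerically stable softmax
        yield {"classes": int(np.argmax(row)), "probabilities": exp / exp.sum()}


# Made-up logits for three examples and three classes.
logits = np.array([[0.1, 2.0, 0.3],
                   [1.5, 0.2, 0.1],
                   [0.0, 0.1, 3.0]])

predictions = fake_predict(logits)

# Consume the generator once, then read both keys from the stored dicts.
results = list(predictions)
pred_class = np.array([p["classes"] for p in results])
pred_prob = np.stack([p["probabilities"] for p in results])
print(pred_class)       # [1 0 2]
print(pred_prob.shape)  # (3, 3)

# The generator is now exhausted, so a second pass yields nothing.
print(list(predictions))  # []
```

With the real Estimator the same applies: iterating `predictions` a second time (for example to extract "probabilities" after "classes") would produce nothing, so store the dicts first.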