How can I export the estimator from the MNIST tutorial on TensorFlow's page and then import it for prediction? Thank you!
The Estimator has a model_dir argument, which is where the model is saved. So during prediction we create the Estimator with the same model_dir and call its predict method, which recreates the graph and loads the checkpoints.
For the MNIST example, the prediction code would be:
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

# cnn_model_fn and eval_data are defined as in the MNIST tutorial.
tf.reset_default_graph()

# An input function to predict the class of new data.
predict_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": eval_data},
    num_epochs=1,
    shuffle=False)

# Recreate the Estimator, pointing model_dir at the saved checkpoints.
mnist_classifier = tf.estimator.Estimator(
    model_fn=cnn_model_fn, model_dir="/tmp/mnist_convnet_model")

# Prediction call
predictions = mnist_classifier.predict(input_fn=predict_input_fn)
pred_class = np.array([p['classes'] for p in predictions]).squeeze()
print(pred_class)

# Output
# [7 2 1 ... 4 5 6]
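
If it helps to see why a fresh Estimator with the same model_fn and model_dir can predict without retraining, here is a small library-free sketch of that pattern. It is only an illustration, not TensorFlow's implementation: ToyEstimator, linear_model_fn and the checkpoint.json file name are all made up for this example, and the "model" is a single linear function.

```python
import json
import os
import tempfile


class ToyEstimator:
    """Hypothetical stand-in for the save/restore pattern; not tf.estimator.Estimator."""

    def __init__(self, model_fn, model_dir):
        self.model_fn = model_fn
        self.model_dir = model_dir

    def _checkpoint_path(self):
        # Made-up file name; real checkpoints use a different format.
        return os.path.join(self.model_dir, "checkpoint.json")

    def train(self, params):
        # "Training" here only persists the given parameters into model_dir.
        os.makedirs(self.model_dir, exist_ok=True)
        with open(self._checkpoint_path(), "w") as f:
            json.dump(params, f)

    def predict(self, inputs):
        # Restore the saved parameters, then apply model_fn to each input.
        with open(self._checkpoint_path()) as f:
            params = json.load(f)
        for x in inputs:
            yield self.model_fn(x, params)


def linear_model_fn(x, params):
    # Hypothetical model: y = w * x + b
    return params["w"] * x + params["b"]


model_dir = tempfile.mkdtemp()

# Training session: writes the checkpoint into model_dir.
ToyEstimator(linear_model_fn, model_dir).train({"w": 2.0, "b": 1.0})

# Later session: a fresh instance with the same model_fn and model_dir.
restored = ToyEstimator(linear_model_fn, model_dir)
predictions = list(restored.predict([0.0, 1.0, 2.0]))
print(predictions)  # [1.0, 3.0, 5.0]
```

The point is that nothing is carried over in memory between the two instances: the model code comes from model_fn and the learned values come from the files in model_dir, which is why both have to match what was used during training.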