python tensorflow logistic-regression

How to restore a trained LinearClassifier from TensorFlow's high-level API and make predictions


I have trained a logistic regression model using TensorFlow's LinearClassifier() class and set the model_dir parameter, which specifies where metagraphs and checkpoints are saved during model training:

import tempfile
import tensorflow as tf  # TensorFlow 1.x; tf.contrib was removed in 2.x

# Create a temporary directory where checkpoints and metagraphs will eventually be saved
model_dir = tempfile.mkdtemp()

logistic_model = tf.contrib.learn.LinearClassifier(
    feature_columns=feature_columns,
    n_classes=num_labels, model_dir=model_dir)

I've been reading about restoring models from metagraphs, but have found nothing about how to do so for models created with the high-level API. LinearClassifier() has a predict() method, but I can't find any documentation on running prediction with a model that has been restored from a checkpoint metagraph. How would I go about doing this? My understanding is that once the model is restored, I am working with a tf.Session object, which lacks all of the built-in functionality of the LinearClassifier class, as in:

import tensorflow as tf

with tf.Session() as sess:
  new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
  new_saver.restore(sess, 'my-save-dir/my-model-10000')
  # Run prediction algorithm...

How do I run the same prediction algorithm used by the high-level API to make predictions with a restored model? Is there a better way to approach this?

Thanks for your input.


Solution

  • LinearClassifier() has a model_dir parameter; when it points to the directory of a trained model, the classifier restores that model.
    During training, you do:

    logistic_model = tf.contrib.learn.LinearClassifier(feature_columns=feature_columns, n_classes=num_labels, model_dir=model_dir)
    # fit() writes checkpoints to model_dir as it trains
    logistic_model.fit(X_train, y_train, steps=10)
    

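    During inference, construct LinearClassifier() with the same feature_columns, n_classes, and model_dir (the same path used for training, not a fresh tempfile.mkdtemp()); it loads the trained model from that path. Don't call fit(); call predict() directly: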

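    logistic_model = tf.contrib.learn.LinearClassifier(feature_columns=feature_columns, n_classes=num_labels, model_dir=model_dir)
    # In TF 1.x, predict() returns a generator by default (as_iterable=True); list() collects the predicted classes
    y_pred = list(logistic_model.predict(X_test))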
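The mechanism the answer relies on can be sketched without TensorFlow: the estimator object holds no learned state of its own; the weights live in checkpoint files under model_dir, and a new instance built with the same configuration and model_dir reads them back before predicting. The class below is a hypothetical, pure-Python stand-in (`TinyLinearClassifier`, saving to a made-up `checkpoint.json`) written only to illustrate that pattern. It is not TensorFlow's API, and it trains a binary logistic regression with plain batch gradient descent rather than anything LinearClassifier does internally.

```python
import json
import math
import os
import tempfile


class TinyLinearClassifier:
    """Hypothetical stand-in for an estimator: all learned state lives in model_dir."""

    def __init__(self, n_features, model_dir):
        self.ckpt_path = os.path.join(model_dir, "checkpoint.json")
        # Start from zeros, then restore saved weights if a checkpoint already exists
        self.weights = [0.0] * n_features
        self.bias = 0.0
        if os.path.exists(self.ckpt_path):
            with open(self.ckpt_path) as f:
                state = json.load(f)
            self.weights, self.bias = state["weights"], state["bias"]

    def _prob(self, x):
        # Numerically stable sigmoid of the linear score
        z = self.bias + sum(w * xi for w, xi in zip(self.weights, x))
        if z >= 0:
            return 1.0 / (1.0 + math.exp(-z))
        e = math.exp(z)
        return e / (1.0 + e)

    def fit(self, X, y, steps=100, lr=0.5):
        # Batch gradient descent on the logistic loss, then write a checkpoint
        n = len(X)
        for _ in range(steps):
            grad_w = [0.0] * len(self.weights)
            grad_b = 0.0
            for x, label in zip(X, y):
                err = self._prob(x) - label
                for j, xj in enumerate(x):
                    grad_w[j] += err * xj / n
                grad_b += err / n
            self.weights = [w - lr * g for w, g in zip(self.weights, grad_w)]
            self.bias -= lr * grad_b
        with open(self.ckpt_path, "w") as f:
            json.dump({"weights": self.weights, "bias": self.bias}, f)
        return self

    def predict(self, X):
        # Class 1 when the predicted probability is at least 0.5
        return [int(self._prob(x) >= 0.5) for x in X]


# Training run: fit() saves the learned weights under model_dir
model_dir = tempfile.mkdtemp()
X_train = [[0.0], [1.0], [2.0], [3.0]]
y_train = [0, 0, 1, 1]
trained = TinyLinearClassifier(n_features=1, model_dir=model_dir).fit(X_train, y_train, steps=500)

# Inference run: a fresh object pointed at the same model_dir restores the weights; no fit() call
restored = TinyLinearClassifier(n_features=1, model_dir=model_dir)
y_pred = restored.predict([[-1.0], [4.0]])
```

As in the answer, fit() only runs during training, and the inference object is reconstructed from model_dir alone; pointing it at a different directory would silently start from untrained (zero) weights.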