Tags: python, tensorflow, tensorflow-datasets, tensorflow-estimator

TensorFlow Estimator.predict_scores does not yield the correct number of predictions when the input function uses the Dataset API


I am using TensorFlow 1.5 and I am puzzled by a behavior that I can't explain.
Here is a minimal example that reproduces it:

import tensorflow as tf
import numpy as np


def input_function(x, y, batch_size=128, shuffle=True, n_epochs=None):
    data_set = tf.data.Dataset.from_tensor_slices({"x": x, "y": y})
    if shuffle:
        data_set = data_set.shuffle(buffer_size=1024, seed=None, reshuffle_each_iteration=True)
    data_set = data_set.batch(batch_size)
    data_set = data_set.repeat(n_epochs)
    iterator = data_set.make_one_shot_iterator()
    example = iterator.get_next()
    return {"features": example["x"]}, example["y"]


def main():
    n_samples = 256
    n_features = 16
    n_labels = 1

    x = np.random.rand(n_samples, n_features).astype(np.float32)
    y = np.random.rand(n_samples, n_labels).astype(np.float32)

    feature_column = tf.contrib.layers.real_valued_column(column_name='features', dimension=n_features)
    estimator = tf.contrib.learn.DNNRegressor([10], [feature_column], optimizer=tf.train.AdamOptimizer())

    estimator.fit(input_fn=lambda: input_function(x, y, batch_size=128, shuffle=True, n_epochs=32))
    pred = estimator.predict_scores(input_fn=lambda: input_function(x, y, batch_size=16, shuffle=False, n_epochs=1))
    print("len(pred) = {} (should be {})".format(len(list(pred)), n_samples))


if __name__ == '__main__':
    main()

In this example, the call to 'fit' seems to work (although I'm not sure), but the call to 'predict_scores' produces only batch_size (=16) predictions instead of n_samples (=256). What am I doing wrong?
This problem disappears if I use tf.estimator.inputs.numpy_input_fn. However, I will eventually need an input function that uses a TFRecordDataset to read a large amount of training data from tfrecord files, similar to what is shown here: https://www.tensorflow.org/programmers_guide/datasets#using_high-level_apis
Any help would be much appreciated.


Solution

  • This is a bug in the tf.contrib.learn.Estimator class, which incorrectly assumes that the input is constant, and only reads one batch, instead of running the input function multiple times to get all of the data. The tf.contrib.learn.Estimator and tf.contrib.learn.DNNRegressor classes are deprecated and slated for removal, so it is unlikely that they will be fixed.

    However, the tf.estimator.DNNRegressor class has been fixed to work with tf.data, and you can modify your code to use it as follows:

    def main():
        n_samples = 256
        n_features = 16
        n_labels = 1
    
        x = np.random.rand(n_samples, n_features).astype(np.float32)
        y = np.random.rand(n_samples, n_labels).astype(np.float32)
    
        feature_column = tf.contrib.layers.real_valued_column(
            column_name='features', dimension=n_features)
    
        # Use the `tf.estimator.DNNRegressor` constructor instead of
        # `tf.contrib.learn.DNNRegressor`.
        estimator = tf.estimator.DNNRegressor(
            [10], [feature_column], optimizer=tf.train.AdamOptimizer())
    
        # Replace `estimator.fit()` with `estimator.train()`.
        estimator.train(input_fn=lambda: input_function(
            x, y, batch_size=128, shuffle=True, n_epochs=32))
    
        # Replace `estimator.predict_scores()` with `estimator.predict()`.
        pred = estimator.predict(input_fn=lambda: input_function(
            x, y, batch_size=16, shuffle=False, n_epochs=1))
    
        print("len(pred) = {} (should be {})".format(len(list(pred)), n_samples))
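
To make the difference concrete, here is a small pure-Python sketch of the two behaviors described above: consuming a single batch versus draining the input until it is exhausted. It is only an illustration and not TensorFlow's actual implementation. The helper names (`input_batches`, `predict_first_batch_only`, `predict_until_exhausted`) and the toy "model" that doubles each value are hypothetical.

```python
def input_batches(samples, batch_size, n_epochs=1):
    """Yield successive batches, mimicking batch() followed by repeat(n_epochs)."""
    for _ in range(n_epochs):
        for start in range(0, len(samples), batch_size):
            yield samples[start:start + batch_size]


def predict_first_batch_only(batches):
    """Hypothetical stand-in for the buggy behavior: read one batch and stop."""
    first = next(iter(batches), [])
    return [2 * value for value in first]  # toy "model": double each input


def predict_until_exhausted(batches):
    """Hypothetical stand-in for the fixed behavior: read every batch."""
    return [2 * value for batch in batches for value in batch]


samples = list(range(256))
buggy = predict_first_batch_only(input_batches(samples, batch_size=16))
fixed = predict_until_exhausted(input_batches(samples, batch_size=16))
print(len(buggy), len(fixed))  # 16 256
```

With 256 samples, a batch size of 16 and a single epoch, the input yields 16 batches, so a predictor that reads all of them returns 256 values, while one that stops after the first batch returns only 16. These are the same counts as in the question.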