Search code examples
Tags: python, machine-learning, tensorflow, tflearn

Example of tensorflow.contrib.learn.ExportStrategy


Can someone provide a complete, working code example that uses TensorFlow's `tf.contrib.learn.ExportStrategy`?

The documentation has no examples, and I could not find any on GitHub or Stack Overflow for this seemingly obscure TensorFlow class.

Documentation: https://www.tensorflow.org/api_docs/python/tf/contrib/learn/ExportStrategy


Solution

  • The Google Cloud ML samples repository contains a complete working example: https://github.com/GoogleCloudPlatform/cloudml-samples/tree/master/census/customestimator/trainer

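    To actually run it you need the sample's full code. The excerpt below relies on `parse_csv`, `LABEL_COLUMN`, `INPUT_COLUMNS`, `generate_experiment_fn` and the parsed command-line `args`, all of which are defined elsewhere in that sample. Here is the gist of how to use ExportStrategy. Note that the ExportStrategy is not constructed directly; `saved_model_export_utils.make_export_strategy` builds it from a serving input function: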

    import tensorflow as tf
    from tensorflow.contrib.learn.python.learn import learn_runner
    from tensorflow.contrib.learn.python.learn.utils import (
        saved_model_export_utils)
    from tensorflow.contrib.training.python.training import hparam
    
    def csv_serving_input_fn():
        """Build a serving input receiver that accepts raw CSV lines.

        Requests supply a 1-D batch of strings under the ``'csv_row'`` key.
        Those strings are turned into feature tensors by ``parse_csv`` (defined
        elsewhere in the sample; presumably it returns a mapping of feature
        name -> tensor, TODO confirm). ``LABEL_COLUMN`` is then removed, because
        the label is not a model input at serving time.

        Returns:
            A ``tf.estimator.export.ServingInputReceiver`` whose features are the
            parsed columns minus the label, and whose receiver tensors are
            ``{'csv_row': <string placeholder>}``.
        """
        raw_lines = tf.placeholder(dtype=tf.string, shape=[None])
        parsed_features = parse_csv(raw_lines)
        # The label column is not fed at prediction time; drop it from the inputs.
        parsed_features.pop(LABEL_COLUMN)
        receiver_tensors = {'csv_row': raw_lines}
        return tf.estimator.export.ServingInputReceiver(
            parsed_features, receiver_tensors)
    
    
    def example_serving_input_fn():
        """Build a serving input receiver that accepts serialized tf.Example protos.

        Requests supply a 1-D batch of serialized protos under the
        ``'example_proto'`` key. They are decoded with ``tf.parse_example``,
        using a parsing spec derived from the feature columns in
        ``INPUT_COLUMNS`` (defined elsewhere in the sample).

        Returns:
            A ``tf.estimator.export.ServingInputReceiver`` mapping the parsed
            features to the model, with ``{'example_proto': <string placeholder>}``
            as the receiver tensors.
        """
        serialized_examples = tf.placeholder(dtype=tf.string, shape=[None])
        parsing_spec = tf.feature_column.make_parse_example_spec(INPUT_COLUMNS)
        decoded_features = tf.parse_example(serialized_examples, parsing_spec)
        receiver_tensors = {'example_proto': serialized_examples}
        return tf.estimator.export.ServingInputReceiver(
            decoded_features, receiver_tensors)
    
    
    def json_serving_input_fn():
        """Build a serving input receiver with one placeholder per input column.

        For every column in ``INPUT_COLUMNS`` (defined elsewhere in the sample),
        a 1-D placeholder is created. It is keyed by the column's ``name`` and
        typed with the column's ``dtype``. The same dict serves as both the
        model features and the receiver tensors, so request fields map straight
        onto features with no parsing step.

        Returns:
            A ``tf.estimator.export.ServingInputReceiver`` built from that dict.
        """
        placeholders = {
            column.name: tf.placeholder(shape=[None], dtype=column.dtype)
            for column in INPUT_COLUMNS
        }
        return tf.estimator.export.ServingInputReceiver(placeholders, placeholders)
    
    
    # Lookup table from export-format name to the function that builds the
    # matching serving input receiver. The training call below selects an entry
    # with ``args.export_format``.
    SERVING_FUNCTIONS = dict(
        JSON=json_serving_input_fn,
        EXAMPLE=example_serving_input_fn,
        CSV=csv_serving_input_fn,
    )
    
    # Launch training. learn_runner.run builds the Experiment through the
    # experiment function returned by generate_experiment_fn. Per the original
    # sample, RunConfig picks up cluster configuration from environment
    # variables, which decides whether this process runs the Experiment or
    # parameter-server code.
    # The export strategy is what ties ExportStrategy into training: it exports
    # a SavedModel using the serving input function chosen by
    # args.export_format, and retains only the most recent export
    # (exports_to_keep=1).
    # NOTE(review): `args` is presumably an argparse Namespace parsed elsewhere
    # in the sample; every one of its attributes is forwarded as an hparam.
    learn_runner.run(
        generate_experiment_fn(
            min_eval_frequency=args.min_eval_frequency,
            eval_delay_secs=args.eval_delay_secs,
            train_steps=args.train_steps,
            eval_steps=args.eval_steps,
            export_strategies=[
                saved_model_export_utils.make_export_strategy(
                    SERVING_FUNCTIONS[args.export_format],
                    exports_to_keep=1,
                ),
            ],
        ),
        run_config=tf.contrib.learn.RunConfig(model_dir=args.job_dir),
        hparams=hparam.HParams(**vars(args)),
    )