Convert Estimator to TPUEstimator


Is it possible to convert an Estimator to a TPUEstimator in TensorFlow without substantially rewriting its functions? I have a model in Estimator form that works nicely on a CPU, but I don't know a convenient way to convert it to a TPUEstimator without having to rewrite the model_fn and input_fn.

The reason this requires significant work to do manually is that I am using Keras to create my model, and then the following helper function to create the Estimator:

   my_keras_model.compile(
                optimizer=tf.keras.optimizers.SGD(lr=0.0001, momentum=0.9),
                loss='categorical_crossentropy',
                metrics=['accuracy'])
   estimator = tf.keras.estimator.model_to_estimator(keras_model=my_keras_model)

It would be great if I could do something like estimator.to_TPU_estimator() or something like that -- perhaps someone knows of a solution?


Solution

  • There can't be such a function, because the model_fn specification differs between the two estimators. Some of the differences run pretty deep, such as this one (from the TPU tutorial):

    When training on a cloud TPU you must wrap the optimizer in a tf.contrib.tpu.CrossShardOptimizer, which uses an allreduce to aggregate gradients and broadcast the result to each shard (each TPU core).

    That means patching the internals of the Keras optimizer and its update ops.
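
To make the quoted sentence concrete, here is a plain-Python sketch of what an allreduce over shards does: aggregate the per-shard gradients, then hand every shard the same result. It does not use TensorFlow at all; all_reduce_mean and the gradient values are made up for illustration, and averaging (rather than summing) is an assumption on my part.

```python
def all_reduce_mean(per_shard_grads):
    """Average gradients element-wise across shards; return one copy per shard."""
    num_shards = len(per_shard_grads)
    # Aggregate: element-wise sum across shards, divided by the shard count.
    mean = [sum(vals) / num_shards for vals in zip(*per_shard_grads)]
    # Broadcast: every shard receives its own copy of the same averaged gradient.
    return [list(mean) for _ in range(num_shards)]


# Hypothetical gradients for two parameters, computed on four shards.
shard_grads = [
    [1.0, 2.0],
    [3.0, 4.0],
    [5.0, 6.0],
    [7.0, 8.0],
]
reduced = all_reduce_mean(shard_grads)
print(reduced[0])  # [4.0, 5.0]
```

Because every shard ends up applying the same update, the model replicas stay in sync, which is why the optimizer has to be wrapped rather than used as-is.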

    The recommended way is to have separate model_fn wrappers for the GPU and TPU models, and that also seems the fastest route for you. In your case, it means rewriting the Keras model_to_estimator function for TPUEstimator.
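
The "separate wrappers" idea can be sketched without TensorFlow: keep the model body in one shared function and let two thin wrappers differ only in the optimizer wrapping and the spec type they return. Everything here (FakeEstimatorSpec, FakeTPUEstimatorSpec, build_core, cross_shard) is a hypothetical stand-in, not a real API.

```python
import collections

# Stand-ins for tf.estimator.EstimatorSpec and tf.contrib.tpu.TPUEstimatorSpec,
# defined here only so the sketch runs without TensorFlow.
FakeEstimatorSpec = collections.namedtuple(
    "FakeEstimatorSpec", ["mode", "loss", "optimizer"])
FakeTPUEstimatorSpec = collections.namedtuple(
    "FakeTPUEstimatorSpec", ["mode", "loss", "optimizer"])


def build_core(features, labels):
    """Hypothetical shared model body: mean squared error as a scalar loss."""
    return sum((f - l) ** 2 for f, l in zip(features, labels)) / len(features)


def cross_shard(optimizer_name):
    """Stand-in for wrapping an optimizer in CrossShardOptimizer."""
    return "CrossShard(%s)" % optimizer_name


def gpu_model_fn(features, labels, mode):
    # CPU/GPU wrapper: plain optimizer, plain spec.
    loss = build_core(features, labels)
    return FakeEstimatorSpec(mode=mode, loss=loss, optimizer="SGD")


def tpu_model_fn(features, labels, mode, params=None):
    # TPU wrapper: same model body, wrapped optimizer, TPU spec.
    loss = build_core(features, labels)
    return FakeTPUEstimatorSpec(
        mode=mode, loss=loss, optimizer=cross_shard("SGD"))


gpu_spec = gpu_model_fn([1.0, 2.0], [0.0, 2.0], "train")
tpu_spec = tpu_model_fn([1.0, 2.0], [0.0, 2.0], "train")
print(gpu_spec.loss, tpu_spec.optimizer)
```

The point of the split is that the model definition is written once, and only the device-specific glue is duplicated.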


    The first and simplest approximation is this:

    def model_to_estimator(keras_model=None,
                           keras_model_path=None,
                           custom_objects=None,
                           model_dir=None,
                           config=None):
      keras_weights = keras_model.get_weights()
      keras_model_fn = _create_keras_tpu_model_fn(keras_model, custom_objects)
      est = tf.contrib.tpu.TPUEstimator(keras_model_fn, model_dir=model_dir, config=config)
      _save_first_checkpoint(keras_model, est, custom_objects, keras_weights)
      return est
    

    Here, the _save_first_checkpoint call is actually optional, but if you'd like to keep it, import the function from tensorflow.python.keras._impl.keras.estimator.


    The real work happens in the _create_keras_tpu_model_fn function, which replaces _create_keras_model_fn. The changes are:

    • the internal tensorflow optimizer must be wrapped with CrossShardOptimizer as mentioned earlier, and

    • the inner function must return TPUEstimatorSpec.

    A few more lines may need patching as well, but it looks OK to me. A complete version is below:
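
One detail behind the second bullet: as far as I know, TPUEstimatorSpec does not accept an eval_metric_ops dict. It takes eval_metrics, a (metric_fn, tensors) pair, and the metrics dict is built by calling metric_fn on those tensors. A dict of metric ops built the Keras way would need to be wrapped accordingly. The plain-Python sketch below only illustrates that contract; resolve_eval_metrics is a hypothetical stand-in for what the estimator does internally, and the data is made up.

```python
def metric_fn(labels, predictions):
    """Builds the metrics dict from raw values, as a TPU-style metric_fn would."""
    correct = sum(1 for l, p in zip(labels, predictions) if l == p)
    return {"accuracy": correct / len(labels)}


def resolve_eval_metrics(eval_metrics):
    """Hypothetical stand-in for how the (metric_fn, tensors) pair is consumed."""
    fn, tensors = eval_metrics
    # A dict of tensors is passed as keyword arguments, a list positionally.
    if isinstance(tensors, dict):
        return fn(**tensors)
    return fn(*tensors)


labels = [1, 0, 1, 1]
predictions = [1, 0, 0, 1]
eval_metrics = (metric_fn, [labels, predictions])
metrics = resolve_eval_metrics(eval_metrics)
print(metrics)  # {'accuracy': 0.75}
```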

    import tensorflow as tf
    # Private TF 1.x module path; it may move or disappear between releases.
    from tensorflow.python.keras._impl.keras.estimator import _save_first_checkpoint, _clone_and_build_model
    
    def model_to_estimator(keras_model=None,
                           keras_model_path=None,
                           custom_objects=None,
                           model_dir=None,
                           config=None):
      keras_weights = keras_model.get_weights()
      keras_model_fn = _create_keras_tpu_model_fn(keras_model, custom_objects)
      est = tf.contrib.tpu.TPUEstimator(keras_model_fn, model_dir=model_dir, config=config)
      _save_first_checkpoint(keras_model, est, custom_objects, keras_weights)
      return est
    
    
    def _create_keras_tpu_model_fn(keras_model, custom_objects=None):
    
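      # TPUEstimator passes the per-shard batch size in params['batch_size'],
      # so model_fn has to accept a params argument.
      def model_fn(features, labels, mode, params):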
        """model_fn for keras Estimator."""
        model = _clone_and_build_model(mode, keras_model, custom_objects, features,
                                       labels)
        predictions = dict(zip(model.output_names, model.outputs))
    
        loss = None
        train_op = None
        eval_metric_ops = None
    
        # Set loss and metric only during train and evaluate.
        if mode != tf.estimator.ModeKeys.PREDICT:
          model.optimizer.optimizer = tf.contrib.tpu.CrossShardOptimizer(model.optimizer.optimizer)
    
          model._make_train_function()  # pylint: disable=protected-access
          loss = model.total_loss
    
          if model.metrics:
            eval_metric_ops = {}
            # When each metric maps to an output
            if isinstance(model.metrics, dict):
              for i, output_name in enumerate(model.metrics.keys()):
                metric_name = model.metrics[output_name]
                if callable(metric_name):
                  metric_name = metric_name.__name__
                # When some outputs use the same metric
                if list(model.metrics.values()).count(metric_name) > 1:
                  metric_name += '_' + output_name
                eval_metric_ops[metric_name] = tf.metrics.mean(
                    model.metrics_tensors[i - len(model.metrics)])
            else:
              for i, metric_name in enumerate(model.metrics):
                if callable(metric_name):
                  metric_name = metric_name.__name__
                eval_metric_ops[metric_name] = tf.metrics.mean(
                    model.metrics_tensors[i])
    
        if mode == tf.estimator.ModeKeys.TRAIN:
          train_op = model.train_function.updates_op
    
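        # TPUEstimatorSpec takes eval_metrics=(metric_fn, tensors) rather than an
        # eval_metric_ops dict, so the metrics collected above are not passed here;
        # wrap them in a metric_fn if evaluation metrics are needed on TPU.
        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions=predictions,
            loss=loss,
            train_op=train_op)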
    
      return model_fn
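
The metric-naming branch inside model_fn is easy to misread, so here is the same logic pulled out into a standalone plain-Python sketch. name_metrics and the example dict are hypothetical; the loop body mirrors the dict case above, where a metric shared by several outputs gets the output name appended.

```python
def name_metrics(metrics_by_output):
    """Returns one metric key per output, suffixing names shared by several outputs."""
    names = []
    for output_name, metric_name in metrics_by_output.items():
        if callable(metric_name):
            metric_name = metric_name.__name__
        # When several outputs use the same metric, disambiguate with the output name.
        if list(metrics_by_output.values()).count(metric_name) > 1:
            metric_name += "_" + output_name
        names.append(metric_name)
    return names


example = {"out_a": "accuracy", "out_b": "accuracy", "out_c": "mse"}
metric_names = name_metrics(example)
print(metric_names)  # ['accuracy_out_a', 'accuracy_out_b', 'mse']
```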