Is it possible to convert an `Estimator` to a `TPUEstimator` in TensorFlow without significant effort in rewriting its functions? I have a model in `Estimator` form that works nicely on a CPU, but I don't know a convenient way to convert it to a `TPUEstimator` without having to rewrite the `model_fn` and `input_fn`.
The reason this requires significant work to do manually is that I am using Keras to create my model, and then the following helper function to create the `Estimator`:
```python
my_keras_model.compile(
    optimizer=tf.keras.optimizers.SGD(lr=0.0001, momentum=0.9),
    # (remaining compile arguments not shown)
)
estimator = tf.keras.estimator.model_to_estimator(keras_model=my_keras_model)
```
It would be great if I could simply call something like `estimator.to_TPU_estimator()`. Perhaps someone knows of a solution?
There can't be such a function, because the `model_fn` specification differs between the two estimators. Some differences are fairly deep, such as this one (from the TPU tutorial):
> When training on a Cloud TPU you must wrap the optimizer in a `tf.contrib.tpu.CrossShardOptimizer`, which uses an `allreduce` to aggregate gradients and broadcast the result to each shard (each TPU core).
In practice, this means patching the internals of the Keras optimizer and its update ops.
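The all-reduce mentioned in the quote can be illustrated without TensorFlow. Below is a minimal pure-Python sketch, not TensorFlow's implementation: `allreduce_mean` is a hypothetical name, plain lists stand in for gradient tensors, and averaging (rather than summing) is an assumption of the sketch. Each shard contributes its own gradients, the values are reduced elementwise, and every shard receives an identical copy of the result.

```python
from typing import List

Vector = List[float]


def allreduce_mean(shard_grads: List[Vector]) -> List[Vector]:
    """Average gradients elementwise across shards; give each shard a copy.

    Simplified model of an all-reduce: shard_grads[k] is the gradient vector
    computed by shard (TPU core) k on its slice of the batch.
    """
    num_shards = len(shard_grads)
    if num_shards == 0:
        raise ValueError("need at least one shard")
    length = len(shard_grads[0])
    if any(len(g) != length for g in shard_grads):
        raise ValueError("all shards must produce gradients of the same shape")
    # Reduce: elementwise sum over shards, divided by the shard count.
    reduced = [sum(g[i] for g in shard_grads) / num_shards for i in range(length)]
    # Broadcast: every shard gets its own copy of the same reduced gradient.
    return [list(reduced) for _ in range(num_shards)]


# Two shards (cores), each with a gradient for the same two weights.
grads = [[1.0, 2.0], [3.0, 6.0]]
print(allreduce_mean(grads))  # [[2.0, 4.0], [2.0, 4.0]]
```

Because every core then applies the same update, the replicated weights stay in sync; an unwrapped optimizer would let each core step using only its own gradients.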
The recommended approach is to have separate `model_fn` wrappers for the GPU and TPU models, and that also seems the fastest route for you. In your case, it means rewriting the Keras `model_to_estimator` function for `TPUEstimator`.
The first and simplest approximation is this:
```python
import tensorflow as tf


def model_to_estimator(keras_model=None, custom_objects=None,
                       model_dir=None, config=None):
  keras_weights = keras_model.get_weights()
  keras_model_fn = _create_keras_tpu_model_fn(keras_model, custom_objects)
  est = tf.contrib.tpu.TPUEstimator(keras_model_fn, model_dir=model_dir, config=config)
  _save_first_checkpoint(keras_model, est, custom_objects, keras_weights)
  return est
```
Here, the `_save_first_checkpoint` call is optional (it seeds the estimator's initial checkpoint with the Keras model's current weights), but if you'd like to keep it, import the function from `tensorflow.python.keras._impl.keras.estimator`.
The real work happens in the `_create_keras_tpu_model_fn` function, which replaces `_create_keras_model_fn`. The changes are:

- the internal TensorFlow optimizer must be wrapped with `CrossShardOptimizer`, as mentioned earlier, and
- the inner function must return a `TPUEstimatorSpec`.

A few more lines may need patching as well, but it looks OK to me. A complete version is below:
```python
import tensorflow as tf
from tensorflow.python.keras._impl.keras.estimator import _save_first_checkpoint, _clone_and_build_model


def model_to_estimator(keras_model=None, custom_objects=None,
                       model_dir=None, config=None):
  keras_weights = keras_model.get_weights()
  keras_model_fn = _create_keras_tpu_model_fn(keras_model, custom_objects)
  est = tf.contrib.tpu.TPUEstimator(keras_model_fn, model_dir=model_dir, config=config)
  _save_first_checkpoint(keras_model, est, custom_objects, keras_weights)
  return est


def _create_keras_tpu_model_fn(keras_model, custom_objects=None):

  # TPUEstimator requires model_fn to accept `params` (it passes the batch
  # size as params['batch_size']), even though it is unused here.
  def model_fn(features, labels, mode, params):
    """model_fn for a Keras TPUEstimator."""
    model = _clone_and_build_model(mode, keras_model, custom_objects, features,
                                   labels)
    predictions = dict(zip(model.output_names, model.outputs))

    loss = None
    train_op = None
    eval_metrics = None

    # Set loss and metrics only during train and evaluate.
    if mode != tf.estimator.ModeKeys.PREDICT:
      # Assumes the model was compiled with a tf.train optimizer, which Keras
      # wraps in a TFOptimizer that exposes it as `.optimizer`.
      model.optimizer.optimizer = tf.contrib.tpu.CrossShardOptimizer(
          model.optimizer.optimizer)

      model._make_train_function()  # pylint: disable=protected-access
      loss = model.total_loss

      if model.metrics:
        metric_tensors = {}
        # When each metric maps to an output
        if isinstance(model.metrics, dict):
          for i, output_name in enumerate(model.metrics.keys()):
            metric_name = model.metrics[output_name]
            if callable(metric_name):
              metric_name = metric_name.__name__
            # When some outputs use the same metric
            if list(model.metrics.values()).count(metric_name) > 1:
              metric_name += '_' + output_name
            metric_tensors[metric_name] = model.metrics_tensors[
                i - len(model.metrics)]
        else:
          for i, metric_name in enumerate(model.metrics):
            if callable(metric_name):
              metric_name = metric_name.__name__
            metric_tensors[metric_name] = model.metrics_tensors[i]

        # TPUEstimatorSpec takes eval_metrics=(metric_fn, tensors) rather than
        # eval_metric_ops; metric_fn runs on the CPU host. The tensors must be
        # batch-major, so each scalar metric is reshaped to shape [1].
        def metric_fn(**tensors):
          return {name: tf.metrics.mean(t) for name, t in tensors.items()}

        eval_metrics = (metric_fn, {name: tf.reshape(t, [1])
                                    for name, t in metric_tensors.items()})

    # Set train_op only during train.
    if mode == tf.estimator.ModeKeys.TRAIN:
      train_op = model.train_function.updates_op

    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metrics=eval_metrics)

  return model_fn
```
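The metric naming rule in the dict branch of the metrics loop (name a metric after its function or string, and suffix the output name when several outputs share a metric) can be isolated as plain Python. This is a minimal sketch: `metric_key_names` and `mae` are hypothetical names, and like the code above it assumes one metric per output. Unlike the code above, it resolves callables to their names *before* counting duplicates; the loop counts the raw dict values against the resolved string, so two outputs sharing a metric *function* would not be detected as duplicates there.

```python
from collections import Counter
from typing import Callable, Dict, Union

Metric = Union[str, Callable]


def metric_key_names(metrics_by_output: Dict[str, Metric]) -> Dict[str, str]:
    """Map each output name to the key its metric gets in the eval metrics."""
    # Resolve callables (e.g. a metric function) to their __name__.
    names = {out: (m.__name__ if callable(m) else m)
             for out, m in metrics_by_output.items()}
    counts = Counter(names.values())
    # A metric shared by several outputs is suffixed with the output name.
    return {out: (name + '_' + out if counts[name] > 1 else name)
            for out, name in names.items()}


def mae(y_true, y_pred):
    """Placeholder metric function; only its __name__ matters here."""
    return 0.0


print(metric_key_names({'price': mae, 'volume': 'mae', 'label': 'accuracy'}))
# {'price': 'mae_price', 'volume': 'mae_volume', 'label': 'accuracy'}
```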