
How to display Runtime Statistics in Tensorboard using Estimator API in a distributed environment


This article illustrates how to add runtime statistics to Tensorboard:

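    import tensorflow as tf  # TensorFlow 1.x API

    # Ask for a full trace of this step and create a container to receive it.
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()
    # sess, merged, train_step, feed_dict, i and train_writer come from the surrounding training loop.
    summary, _ = sess.run([merged, train_step],
                          feed_dict=feed_dict(True),
                          options=run_options,
                          run_metadata=run_metadata)
    # Attach the trace to this step so Tensorboard can display it.
    train_writer.add_run_metadata(run_metadata, 'step%d' % i)
    train_writer.add_summary(summary, i)
    print('Adding run metadata for', i)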

which adds runtime statistics, such as the compute time and memory usage of each node, to Tensorboard's graph view:

[Figure: Runtime Statistics in Tensorboard]

This is fairly straightforward on a single machine. How could one do this in a distributed environment using Estimators?


Solution

  • You can use tf.train.ProfilerHook. The catch is that it only became available in TensorFlow 1.14.

    Example usage:

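    import tensorflow as tf  # needs a TensorFlow version that ships tf.train.ProfilerHook

    estimator = tf.estimator.LinearClassifier(...)
    # Record a trace roughly every 600 seconds; model_dir and train_input_fn are defined elsewhere.
    hooks = [tf.train.ProfilerHook(output_dir=model_dir, save_secs=600, show_memory=False)]
    estimator.train(input_fn=train_input_fn, hooks=hooks)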
    

    Running the hook generates files named timeline-xx.json in output_dir.

    Then open chrome://tracing/ in the Chrome browser and load one of these files. You will get a timeline showing how much time each operation took.
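The usage example calls estimator.train directly. In a distributed job driven by tf.estimator.train_and_evaluate, hooks are instead passed through tf.estimator.TrainSpec(hooks=...), and as far as I know such hooks run on every training task. The sketch below is an untested illustration under that assumption: it gives each task its own profiling subdirectory, derived from the TF_CONFIG environment variable that distributed Estimator jobs use, so timeline files from workers that share model_dir stay separate. profiler_output_dir, make_train_spec and the "profile" subdirectory are hypothetical names, not part of TensorFlow.

```python
import json
import os


def profiler_output_dir(model_dir, tf_config=None):
    """Hypothetical helper: return a per-task directory for ProfilerHook output.

    Reads the TF_CONFIG JSON (e.g. {"task": {"type": "worker", "index": 1}})
    and appends "<type>_<index>" so each task writes its timeline files
    to its own subdirectory. Falls back to "local_0" when TF_CONFIG is unset.
    """
    if tf_config is None:
        tf_config = os.environ.get("TF_CONFIG", "{}")
    task = json.loads(tf_config).get("task", {})
    task_type = task.get("type", "local")
    task_index = task.get("index", 0)
    return os.path.join(model_dir, "profile", "%s_%s" % (task_type, task_index))


def make_train_spec(train_input_fn, model_dir, max_steps=None):
    """Hypothetical helper: build a TrainSpec that attaches a ProfilerHook."""
    # TensorFlow 1.x, imported lazily so profiler_output_dir needs only the stdlib.
    import tensorflow as tf

    hook = tf.train.ProfilerHook(output_dir=profiler_output_dir(model_dir),
                                 save_secs=600, show_memory=False)
    return tf.estimator.TrainSpec(input_fn=train_input_fn,
                                  max_steps=max_steps,
                                  hooks=[hook])
```

Each task would then call tf.estimator.train_and_evaluate(estimator, make_train_spec(train_input_fn, model_dir), eval_spec), with eval_spec built as usual. A per-task directory is a design choice rather than a requirement; it simply makes it obvious which worker produced which timeline.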
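If you want a quick text summary in addition to the chrome://tracing view, the timeline files can be read with the standard library. This sketch assumes the files follow the Chrome Trace Event Format that chrome://tracing loads: complete events carry "ph": "X" and a "dur" in microseconds. summarize_trace and summarize_timeline are hypothetical helper names.

```python
import json
from collections import defaultdict


def summarize_trace(trace, top=10):
    """Sum the durations of complete ("X") events per event name.

    Accepts either an object with a "traceEvents" list or a bare event list,
    and returns the `top` (name, total_microseconds) pairs, largest first.
    """
    events = trace["traceEvents"] if isinstance(trace, dict) else trace
    totals = defaultdict(int)
    for event in events:
        if event.get("ph") == "X":  # only complete events have a duration
            totals[event.get("name", "?")] += event.get("dur", 0)
    return sorted(totals.items(), key=lambda kv: kv[1], reverse=True)[:top]


def summarize_timeline(path, top=10):
    """Load a timeline-xx.json file from disk and summarize it."""
    with open(path) as f:
        return summarize_trace(json.load(f), top=top)
```

For example, summarize_timeline applied to one of the generated timeline-xx.json files returns the event names that accounted for the most time in that traced step.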