tensorflow, tensorboard

How to visualize a TensorFlow graph without running train/evaluate with the Estimator API?


How can I visualize the graph on TensorBoard using the Estimator API of TensorFlow without running training or evaluation?

I know how to do this with the Session API, where you have access to the Graph object, but I could not find anything equivalent for the Estimator API.


Solution

  • Estimators create and manage tf.Graph and tf.Session objects for you, so these objects are not easily accessible. Note that, by default, the graph is exported to the events file in the estimator's model_dir when you call estimator.train.

    What you can do, however, is call your model_fn outside of tf.estimator, inside a graph you create yourself, and then use the classic tf.summary.FileWriter to export that graph.

    Here is a code snippet with a very simple estimator that just applies a dense layer to the input:

    import tensorflow as tf
    import numpy as np
    
    # Basic input_fn
    def input_fn(x, y, batch_size=4):
        dataset = tf.data.Dataset.from_tensor_slices((x, y))
        dataset = dataset.batch(batch_size).repeat(1)
        return dataset
    
    # Basic model_fn that just applies a dense layer to the input
    def model_fn(features, labels, mode):
        global_step = tf.train.get_or_create_global_step()
    
        y = tf.layers.dense(features, 1)
    
        increment_global_step = tf.assign_add(global_step, 1)
    
        return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={'preds':y},
                loss=tf.constant(0.0, tf.float32),
                train_op=increment_global_step)
    
    # Fake data
    x = np.random.normal(size=[10, 100])
    y = np.random.normal(size=[10])
    
    # Just to show that the estimator works
    estimator = tf.estimator.Estimator(model_fn=model_fn)
    estimator.train(input_fn=lambda: input_fn(x, y), steps=1)
    
    
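    # Classic way of exporting the graph: build it with placeholders and an outside call to model_fn
    with tf.Graph().as_default() as g:
        # Placeholders standing in for the features and labels that input_fn would produce
        features = tf.placeholder(tf.float32, x.shape)
        labels = tf.placeholder(tf.float32, y.shape)
    
        # Builds the model's ops in g; nothing is executed
        _ = model_fn(features, labels, tf.estimator.ModeKeys.TRAIN)
    
        # Export the graph to ./graph (no session is needed) and close the writer
        # so the events file is flushed to disk; then run: tensorboard --logdir ./graph
        writer = tf.summary.FileWriter('./graph', g)
        writer.close()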
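A variation on the same idea, sketched under the assumption of a TensorFlow release that provides tf.compat.v1 (late 1.x or 2.x): instead of placeholders, feed model_fn the tensors produced by the real input_fn. This is closer to what estimator.train builds, so the exported graph also shows the input pipeline. The helper name export_model_fn_graph is hypothetical, and it assumes input_fn() returns a tf.data.Dataset of (features, labels) pairs.

```python
import tensorflow as tf

tf1 = tf.compat.v1


def export_model_fn_graph(model_fn, input_fn, logdir, mode="train"):
    """Build the graph an Estimator would build for `mode` and write it to an
    events file in `logdir`, without creating a session or running anything.

    Hypothetical helper. Assumes input_fn() returns a tf.data.Dataset whose
    elements are (features, labels) pairs, and that model_fn(features, labels,
    mode) only adds ops to the default graph. "train" is the value of
    tf.estimator.ModeKeys.TRAIN.
    """
    graph = tf.Graph()
    with graph.as_default():
        dataset = input_fn()
        # Take one (features, labels) element so the input pipeline is in the graph
        features, labels = tf1.data.make_one_shot_iterator(dataset).get_next()
        # The return value (e.g. an EstimatorSpec) is discarded; the ops stay in `graph`
        model_fn(features, labels, mode)
        # FileWriter only serializes the graph definition; on TF 2.x it must be
        # created while a graph, not eager execution, is active
        writer = tf1.summary.FileWriter(logdir, graph)
        writer.close()
    return graph
```

With the question's definitions, the call would be export_model_fn_graph(model_fn, lambda: input_fn(x, y), './graph'), followed by tensorboard --logdir ./graph. On TensorFlow 2.x, the question's model_fn would first need its tf.train, tf.layers and tf.assign_add calls switched to their tf.compat.v1 counterparts. Placeholders keep the exported graph smaller; the input_fn route trades that for a graph that matches training more closely.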
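To confirm that an events file really contains a graph (for example the ones estimator.train writes into model_dir, or the one written to ./graph in the snippet above), you can scan its events for a serialized GraphDef. This is a minimal sketch: the helper name events_contain_graph is hypothetical, it relies on tf.compat.v1.train.summary_iterator, and it assumes the files follow the usual events.out.tfevents.* naming.

```python
import glob
import os

import tensorflow as tf


def events_contain_graph(logdir):
    """Return True if any events file in `logdir` holds a serialized GraphDef.

    TensorBoard's Graphs tab looks for events like these, whether they were
    written by estimator.train or by a manual tf.compat.v1.summary.FileWriter.
    """
    for path in glob.glob(os.path.join(logdir, "events.out.tfevents.*")):
        for event in tf.compat.v1.train.summary_iterator(path):
            # graph_def is a bytes field; it is empty unless this event stores a graph
            if event.graph_def:
                return True
    return False
```

Pointing it at an estimator's model_dir after training should return True, since estimator.train exports the graph by default.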