How can I visualize the graph on TensorBoard using the Estimator API of TensorFlow without running training or evaluation?
I know how it is achieved with the session API when you have access to the Graph object, but could not find anything for the Estimator API.
Estimators create and manage the tf.Graph and tf.Session objects for you, so these objects are not easily accessible. Note that, by default, the graph is exported inside the events file when you call estimator.train.
What you can do, however, is call your model_fn outside of tf.estimator and then use the classic tf.summary.FileWriter() to export the graph.
Here is a code snippet with a very simple estimator that just applies a dense layer to the input:
import tensorflow as tf
import numpy as np
# Basic input_fn
def input_fn(x, y, batch_size=4):
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    dataset = dataset.batch(batch_size).repeat(1)
    return dataset
# Basic model_fn that just applies a dense layer to the input
def model_fn(features, labels, mode):
    global_step = tf.train.get_or_create_global_step()
    y = tf.layers.dense(features, 1)
    increment_global_step = tf.assign_add(global_step, 1)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions={'preds': y},
        loss=tf.constant(0.0, tf.float32),
        train_op=increment_global_step)
# Fake data
x = np.random.normal(size=[10, 100])
y = np.random.normal(size=[10])
# Just to show that the estimator works
estimator = tf.estimator.Estimator(model_fn=model_fn)
estimator.train(input_fn=lambda: input_fn(x, y), steps=1)
# Classic way of exporting the graph using placeholders and an outside call to the model_fn
with tf.Graph().as_default() as g:
    # Placeholders
    features = tf.placeholder(tf.float32, x.shape)
    labels = tf.placeholder(tf.float32, y.shape)
    # Creates the graph
    _ = model_fn(features, labels, None)
    # Export the graph to ./graph
    with tf.Session() as sess:
        train_writer = tf.summary.FileWriter('./graph', sess.graph)
        # Close the writer so the graph is flushed to disk
        train_writer.close()
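To check that the export worked before opening TensorBoard, you can look for the events file that the FileWriter produces. This is only a minimal sketch, assuming the ./graph log directory used above and the usual events.out.tfevents.* file naming; find_event_files is a hypothetical helper written for illustration, not part of TensorFlow:

```python
import glob
import os


def find_event_files(logdir):
    """Return the sorted paths of event files found directly in logdir.

    Hypothetical helper: it only matches the events.out.tfevents.* naming
    pattern and does not read or validate the file contents.
    """
    pattern = os.path.join(logdir, "events.out.tfevents.*")
    return sorted(glob.glob(pattern))


if __name__ == "__main__":
    # "./graph" is the log directory assumed from the snippet above
    files = find_event_files("./graph")
    if files:
        print("Found event files:", files)
        print("View the graph with: tensorboard --logdir ./graph")
    else:
        print("No event files in ./graph; was the writer flushed or closed?")
```

If a file is listed, running tensorboard --logdir ./graph and opening the Graphs tab should show the exported graph.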