Search code examples
Tags: tensorflow, jupyter, graph-visualization, tensorboard

Simple way to visualize a TensorFlow graph in Jupyter?


The official way to visualize a TensorFlow graph is with TensorBoard, but sometimes I just want a quick look at the graph when I'm working in Jupyter.

Is there a quick solution, ideally based on TensorFlow tools, or standard SciPy packages (like matplotlib), but if necessary based on 3rd party libraries?


Solution

  • TensorFlow 2.0 now supports TensorBoard in Jupyter via magic commands (e.g. %tensorboard --logdir logs/train). The TensorBoard guide "Using TensorBoard in Notebooks" has tutorials and examples.

    [EDITS 1, 2]

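    As @MiniQuark mentioned in a comment, we need to load the extension first (%load_ext tensorboard). Early TensorBoard releases used %load_ext tensorboard.notebook instead, and current releases reject that older form.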

    Below are examples using graph mode, @tf.function, and tf.keras (originally written against tensorflow==2.0.0-alpha0):

    1. Example using graph mode in TF2 (via tf.compat.v1.disable_eager_execution())

    %load_ext tensorboard.notebook
    import tensorflow as tf
    tf.compat.v1.disable_eager_execution()
    from tensorflow.python.ops.array_ops import placeholder
    from tensorflow.python.training.gradient_descent import GradientDescentOptimizer
    from tensorflow.python.summary.writer.writer import FileWriter
    
    with tf.name_scope('inputs'):
       x = placeholder(tf.float32, shape=[None, 2], name='x')
       y = placeholder(tf.int32, shape=[None], name='y')
    
    with tf.name_scope('logits'):
       layer = tf.keras.layers.Dense(units=2)
       logits = layer(x)
    
    with tf.name_scope('loss'):
       xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
       loss_op = tf.reduce_mean(xentropy)
    
    with tf.name_scope('optimizer'):
       optimizer = GradientDescentOptimizer(0.01)
       train_op = optimizer.minimize(loss_op)
    
    FileWriter('logs/train', graph=train_op.graph).close()
    %tensorboard --logdir logs/train
    

    2. The same model as above, but using the @tf.function decorator for the forward and backward passes, with eager execution left enabled:

    %load_ext tensorboard.notebook
    import tensorflow as tf
    import numpy as np
    
    logdir = 'logs/'
    writer = tf.summary.create_file_writer(logdir)
    tf.summary.trace_on(graph=True, profiler=True)
    
    @tf.function
    def forward_and_backward(x, y, w, b, lr=tf.constant(0.01)):
        """Run one gradient-descent step of a linear softmax classifier.

        x: input features; y: class labels for sparse softmax cross-entropy;
        w, b: variables holding the weights and bias, updated in place via
        `assign`; lr: learning rate. Returns the summed cross-entropy loss
        computed before the update.
        """
        with tf.name_scope('logits'):
            logits = tf.matmul(x, w) + b

        with tf.name_scope('loss'):
            loss_fn = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=y, logits=logits)
            reduced = tf.reduce_sum(loss_fn)

        with tf.name_scope('optimizer'):
            grads = tf.gradients(reduced, [w, b])
            # Plain loop for side effects, with names that don't shadow the input `x`.
            for var, grad in zip([w, b], grads):
                var.assign(var - grad * lr)
        return reduced
    
    # inputs
    x = tf.convert_to_tensor(np.ones([1, 2]), dtype=tf.float32)
    y = tf.convert_to_tensor(np.array([1]))
    # params
    w = tf.Variable(tf.random.normal([2, 2]), dtype=tf.float32)
    b = tf.Variable(tf.zeros([1, 2]), dtype=tf.float32)
    
    loss_val = forward_and_backward(x, y, w, b)
    
    with writer.as_default():
        tf.summary.trace_export(
            name='NN',
            step=0,
            profiler_outdir=logdir)
    
    %tensorboard --logdir logs/
    

    3. Using tf.keras API:

    %load_ext tensorboard.notebook
    import tensorflow as tf
    import numpy as np
    x_train = [np.ones((1, 2))]
    y_train = [np.ones(1)]
    
    model = tf.keras.models.Sequential([tf.keras.layers.Dense(2, input_shape=(2, ))])
                                        
    model.compile(
        optimizer='sgd',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
    
    logdir = "logs/"
    
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
    
    model.fit(x_train,
              y_train,
              batch_size=1,
              epochs=1,
              callbacks=[tensorboard_callback])
    
    %tensorboard --logdir logs/
    

    These examples will produce something like this below the cell:

    [Screenshot: the TensorBoard dashboard rendered inline below the notebook cell, showing the model graph]