python-3.x tensorflow tensorboard tensorflow2.0 tf.keras

How to graph tf.keras model in Tensorflow-2.0?


I upgraded to TensorFlow 2.0 and tf.summary.FileWriter("tf_graphs", sess.graph) no longer exists. Other Stack Overflow questions on this suggest falling back to tf.compat.v1.summary, but surely there is a native way to graph and visualize a tf.keras model in TensorFlow 2. What is it? I'm looking for TensorBoard's graph view of the model. Thank you!


Solution

  • According to the docs, TensorBoard can visualise the graph of a tf.keras model once it has been trained with the TensorBoard callback attached.

    First, define and train your model. Then open TensorBoard and switch to the Graphs tab.


    Minimal Compilable Example

    This example is taken from the docs. First, define your model and data.

    # Load the TensorBoard notebook extension (Jupyter/Colab only).
    %load_ext tensorboard
    
    # Relevant imports. The __future__ imports are not needed on Python 3.
    from datetime import datetime
    
    import tensorflow as tf
    from tensorflow import keras
    
    # Define the model.
    model = keras.models.Sequential([
        keras.layers.Flatten(input_shape=(28, 28)),
        keras.layers.Dense(32, activation='relu'),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10, activation='softmax')
    ])
    
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
    
    # Load Fashion-MNIST and scale pixel values from [0, 255] to [0, 1].
    (train_images, train_labels), _ = keras.datasets.fashion_mnist.load_data()
    train_images = train_images / 255.0
    

    Next, train your model. Pass a Keras TensorBoard callback to fit(); during training it writes the metrics and the model graph to the log directory.

    # Define the Keras TensorBoard callback.
    logdir = "logs/fit/" + datetime.now().strftime("%Y%m%d-%H%M%S")
    tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)
    
    # Train the model.
    model.fit(
        train_images,
        train_labels, 
        batch_size=64,
        epochs=5, 
        callbacks=[tensorboard_callback])
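
The timestamp in `logdir` gives every training run its own subdirectory, so TensorBoard pointed at the parent `logs` directory lists each run separately. Here is a minimal sketch of that naming scheme using only the standard library. The helper name `make_run_logdir` is hypothetical; it is not part of TensorFlow or of the docs example.

```python
from datetime import datetime
from pathlib import PurePosixPath


def make_run_logdir(root="logs/fit", now=None):
    """Return a per-run log directory such as 'logs/fit/20200102-030405'.

    `make_run_logdir` is an illustrative helper, not a TensorFlow API.
    """
    now = now if now is not None else datetime.now()
    # Same timestamp format as the snippet above: YYYYMMDD-HHMMSS.
    return str(PurePosixPath(root) / now.strftime("%Y%m%d-%H%M%S"))


# Two runs started at different times land in different subdirectories.
first = make_run_logdir(now=datetime(2020, 1, 2, 3, 4, 5))
second = make_run_logdir(now=datetime(2020, 1, 2, 3, 9, 30))
print(first)   # logs/fit/20200102-030405
print(second)  # logs/fit/20200102-030930
```

The returned string can be passed as `log_dir` to `keras.callbacks.TensorBoard`, exactly like the `logdir` variable above.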
    

    After training, run the following in your notebook (from a terminal, run tensorboard --logdir logs instead):

    %tensorboard --logdir logs
    

    Then switch to the Graphs tab in the navbar. It shows the op-level graph by default; to see the layer-by-layer view of the Keras model, select the keras tag on the left.
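
The callback covers tf.keras models trained with `fit()`. For code that runs outside `fit()`, one TensorFlow 2 way to get what `FileWriter(..., sess.graph)` used to provide is to trace a `tf.function` and export the trace with `tf.summary`. This is a minimal sketch, assuming TensorFlow 2.x is installed. The function `my_func`, the `logs/func` directory, and the trace name are illustrative choices, not names from the question.

```python
from datetime import datetime

import tensorflow as tf


@tf.function
def my_func(x, y):
    # A small computation to trace; any tf.function works here.
    return tf.nn.relu(tf.matmul(x, y))


# Write to a timestamped directory under logs/ so TensorBoard finds it.
logdir = "logs/func/" + datetime.now().strftime("%Y%m%d-%H%M%S")
writer = tf.summary.create_file_writer(logdir)

x = tf.random.uniform((3, 3))
y = tf.random.uniform((3, 3))

# Start recording the graph, call the function once, then export the trace.
tf.summary.trace_on(graph=True)
z = my_func(x, y)
with writer.as_default():
    tf.summary.trace_export(name="my_func_trace", step=0)
writer.flush()
```

After running it, `%tensorboard --logdir logs` should list the traced function in the Graphs tab once the corresponding run is selected.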