
Keras Visualization of Model Built from Functional API


I wanted to ask whether there is an easy way to visualize a Keras model built with the Functional API.

Right now, the best way for me to debug a Sequential model at a high level is:

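from IPython.display import SVG
from keras.models import Sequential
from keras.utils.vis_utils import model_to_dot

model = Sequential()
model.add(...)  # layers elided
...

model.summary()  # prints the summary table itself (returns None)
SVG(model_to_dot(model).create(prog='dot', format='svg'))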

However, I am having a hard time finding a good way to visualize a model built with the Keras Functional API when it is more complex and non-sequential.


Solution

  • Yes, there is: check keras.utils, which provides a plot_model() function, as explained in detail in the Keras documentation. It seems you are already familiar with keras.utils.vis_utils and its model_to_dot function, but this is another option. Its usage is something like:

    # plot_model needs the pydot package and Graphviz to be installed
    from keras.utils import plot_model
    plot_model(model, to_file='model.png')
    

    To be honest, that is the best I have managed to find using Keras alone. Using model.summary() as you did is also useful at times. I also wish there were a tool for better visualization of one's models, perhaps even one that shows the weights per layer to help decide on optimal network structures and initializations (if you know of one, please tell :] ).


    Probably the best option you currently have is to visualize things in TensorBoard, which you can hook into Keras with the TensorBoard callback. This lets you visualize your training and the metrics of interest, the model graph itself (with write_graph=True), and some information on your layers' activations, biases, kernels, etc. Basically, you add this code to your program before fitting your model:

    from keras.callbacks import TensorBoard

    # Folder to write the logs to, plus other options
    tensorboard = TensorBoard(log_dir='./logs/run1', histogram_freq=1,
                              write_graph=True, write_images=False)

    # Put it in your callback list, alongside any other callbacks
    callbacks_list = [tensorboard]

    # Then pass the list to fit(); remember to also pass validation_data
    model.fit(X, Y, callbacks=callbacks_list, epochs=64,
              validation_data=(X_test, Y_test), shuffle=True)
    

    You can then start TensorBoard (which runs as a local web server) from your terminal with:

    tensorboard --logdir=./logs/run1
    

    It will then tell you which port to open in your browser to view your training (6006 by default). If you have several runs, you can pass --logdir=./logs instead to view them together for comparison. TensorBoard has many more options, so I suggest checking its documentation if you are considering using it.
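For intuition about what plot_model() and model_to_dot draw for a non-sequential model: a Functional API model is a directed acyclic graph of layers rather than a single chain, and the resulting Graphviz (DOT) diagram has roughly one node per layer and one edge per inbound connection. Below is a minimal, library-free sketch of that idea. The graph_to_dot helper and the layer names are hypothetical, and this is not how Keras itself implements model_to_dot.

```python
# Hypothetical sketch: render a small non-sequential layer graph as Graphviz DOT.
# Layer names and the `inbound` mapping are made up for illustration; this is
# not the Keras implementation of model_to_dot.


def graph_to_dot(inbound):
    """Return DOT source for a graph given as {layer_name: [inbound layer names]}."""
    lines = ["digraph model {", "  rankdir=TB;"]
    # One node per layer
    for name in inbound:
        lines.append(f'  "{name}" [shape=box];')
    # One edge per inbound connection (parent -> layer)
    for name, parents in inbound.items():
        for parent in parents:
            lines.append(f'  "{parent}" -> "{name}";')
    lines.append("}")
    return "\n".join(lines)


# A two-branch model: two inputs, each with its own Dense layer,
# merged by a concatenation before the output layer.
layers = {
    "input_a": [],
    "input_b": [],
    "dense_a": ["input_a"],
    "dense_b": ["input_b"],
    "concat": ["dense_a", "dense_b"],
    "output": ["concat"],
}

dot_source = graph_to_dot(layers)
print(dot_source)
```

If you save the printed text to model.dot, Graphviz can render it with `dot -Tpng model.dot -o model.png`. The branching and merging that a Sequential chain cannot express show up directly as the two edges entering "concat".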
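Comparing several runs by pointing TensorBoard at a parent log directory such as ./logs works best when every run writes to its own subdirectory. A minimal sketch, assuming a timestamp-based naming scheme suits you; make_run_dir is a hypothetical helper, not part of Keras, and its return value would be passed as log_dir to the TensorBoard callback shown above.

```python
import os
from datetime import datetime


def make_run_dir(base="./logs", now=None):
    """Create and return a fresh per-run log directory, e.g. ./logs/run-20240102-030405.

    Hypothetical helper (not part of Keras): one timestamped subdirectory per
    run, so that pointing TensorBoard at `base` lists every run separately.
    """
    if now is None:
        now = datetime.now()
    path = os.path.join(base, now.strftime("run-%Y%m%d-%H%M%S"))
    os.makedirs(path, exist_ok=True)  # no error if it already exists
    return path


# Usage with the callback from the answer (commented out so this sketch
# runs without Keras installed):
# tensorboard = TensorBoard(log_dir=make_run_dir(), histogram_freq=1,
#                           write_graph=True, write_images=False)
```

Timestamps at one-second resolution mean two runs started within the same second would share a directory; appending a counter or a short description of the run avoids that if it matters.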