How can I better organize the nodes in TensorBoard when using Keras?


I'm using Keras instead of raw TensorFlow because of its simplicity. However, when I tried to visualize the computational graph by passing a keras.callbacks.TensorBoard instance to the callbacks argument of model.fit(), the graph I got in TensorBoard was very awkward. For demonstration purposes, I built only a very simple linear classifier: one Dense layer with a single unit. Even so, the resulting graph is cluttered.

Can I do the same thing we do in TensorFlow, i.e. use a name scope (tf.name_scope) to group operations together and give layers, biases and weights readable names? The graph here is a mess: I can only make out the Dense layer and a logistic loss namespace. With plain TensorFlow we would typically see something like a train namespace, and far fewer nodes sitting outside any namespace. How can I make the graph clearer?
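For reference, the setup described in the question can be sketched as follows. This is a minimal reproduction under assumptions: it uses tf.keras (TensorFlow 2.x) rather than the standalone keras package, and the toy data, optimizer, epoch count and temporary log directory are illustrative choices, not taken from the question.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical toy data (not from the question): 2 features, binary labels.
rng = np.random.default_rng(0)
x = rng.normal(size=(256, 2)).astype("float32")
y = (x[:, 0] + x[:, 1] > 0).astype("float32")

# A linear classifier: one Dense layer with a single sigmoid unit.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(2,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="sgd", loss="binary_crossentropy")

# The TensorBoard callback writes event files (including the graph) under log_dir.
# The temporary directory is an illustrative assumption.
log_dir = os.path.join(tempfile.mkdtemp(), "logs")
tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir=log_dir, write_graph=True)

history = model.fit(x, y, epochs=2, batch_size=32,
                    callbacks=[tensorboard_cb], verbose=0)
```

After training, running tensorboard --logdir with that directory and opening the Graphs tab shows the kind of graph discussed here.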


Solution

  • The TensorFlow graph shows every operation that is actually executed, so you won't be able to simplify it.

    As an alternative, Keras can draw its own layer-by-layer graph, which shows a clear and concise structure of your network. You can generate it by calling the following (plot_model requires the pydot and graphviz packages to be installed):

    from keras.utils import plot_model
    plot_model(model, to_file='/some/pathname/model.png')
    

    Finally, you can also call model.summary(), which prints a textual version of the graph with additional information such as each layer's output shape and parameter count.

    Here is an example output of model.summary():

    Layer (type)                     Output Shape          Param #     Connected to                     
    ====================================================================================================
    input_1 (InputLayer)             (None, 2048)          0                                            
    ____________________________________________________________________________________________________
    activation_1 (Activation)        (None, 2048)          0                                            
    ____________________________________________________________________________________________________
    dense_1 (Dense)                  (None, 511)           1047039                                      
    ____________________________________________________________________________________________________
    activation_2 (Activation)        (None, 511)           0                                            
    ____________________________________________________________________________________________________
    decoder_layer_1 (DecoderLayer)   (None, 512)           0                                            
    ____________________________________________________________________________________________________
    ctg_output (OrLayer)             (None, 201)           102912                                       
    ____________________________________________________________________________________________________
    att_output (OrLayer)             (None, 312)           159744                                       
    ====================================================================================================
    Total params: 1,309,695.0
    Trainable params: 1,309,695.0
    Non-trainable params: 0.0
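The Param # column can be checked by hand. For a standard Dense layer the count is inputs × units (the kernel) plus units (one bias per unit), which reproduces the dense_1 row. The ctg_output and att_output rows come from a custom OrLayer whose implementation isn't shown; their counts happen to match a bias-free mapping from the 512-wide decoder_layer_1 output, but that interpretation is an assumption, and dense_params below is a hypothetical helper written for this check.

```python
def dense_params(inputs: int, units: int, use_bias: bool = True) -> int:
    """Trainable parameters of a fully connected layer: a kernel of shape
    (inputs, units) plus, optionally, one bias per unit."""
    return inputs * units + (units if use_bias else 0)


# Standard Dense row from the summary above: 2048 inputs -> 511 units.
dense_1 = dense_params(2048, 511)

# Assumption: the custom OrLayer rows behave like bias-free dense maps
# from the 512-wide decoder output; their real implementation is not shown.
ctg_output = dense_params(512, 201, use_bias=False)
att_output = dense_params(512, 312, use_bias=False)

# The Activation, InputLayer and DecoderLayer rows report 0 parameters.
total = dense_1 + ctg_output + att_output
print(f"{total:,}")  # prints 1,309,695
```

The sum matches the "Total params" line of the summary.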