I'm using Keras instead of working with TensorFlow directly because of its simplicity. But when I tried to visualize the computational graph in Keras by passing a keras.callbacks.TensorBoard instance to the callbacks argument of model.fit(), the graph I got from TensorBoard was very awkward.
For demonstration purposes, I built only a very simple linear classifier with 1 unit in 1 Dense layer, but the graph looks like this:
Can I do the same thing we do in TensorFlow, such as using name scopes (tf.name_scope) to group things together and giving the layers, biases, and weights names? The graph here is such a mess that I can only make out the Dense layer and a logistic loss namespace. With TensorFlow we typically see something like a train namespace, and not so many nodes sitting outside any namespace. How can I make it clearer?
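For reference, the setup described above can be sketched roughly as follows. The two input features, the random toy data, the optimizer, and the temporary log directory are all assumptions made for illustration; only the single-unit Dense layer and the TensorBoard callback come from the description.

```python
import os
import tempfile

import numpy as np
import keras

# Hypothetical toy data: 2 input features, binary labels.
x = np.random.rand(64, 2).astype("float32")
y = (x.sum(axis=1) > 1.0).astype("float32")

# One Dense layer with a single unit, as described above.
model = keras.Sequential([
    keras.Input(shape=(2,)),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="sgd", loss="binary_crossentropy")

# log_dir is an arbitrary temporary directory chosen for this sketch.
log_dir = tempfile.mkdtemp()
tensorboard = keras.callbacks.TensorBoard(log_dir=log_dir, write_graph=True)

# The callback is passed through the callbacks argument of fit().
model.fit(x, y, epochs=1, batch_size=16, verbose=0, callbacks=[tensorboard])
print("TensorBoard logs written under:", log_dir)
```

The graph can then be inspected by pointing the tensorboard command-line tool at that directory with its --logdir option.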
The TensorFlow graph shows every computation that is called, so you won't be able to simplify it.
As an alternative, Keras has its own layer-by-layer graph, which shows a clear and concise structure of your network. You can generate it by calling:
from keras.utils import plot_model
plot_model(model, to_file='/some/pathname/model.png')
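A slightly fuller sketch of that call, assuming a toy model with two input features and a layer named "classifier" (both hypothetical). As far as I know, plot_model relies on the pydot package and a Graphviz installation, so the call is guarded here in case they are missing.

```python
import os
import tempfile

import keras
from keras.utils import plot_model

# Hypothetical toy model: 2 input features, one Dense unit.
model = keras.Sequential([
    keras.Input(shape=(2,)),
    keras.layers.Dense(1, activation="sigmoid", name="classifier"),
])

# Write the diagram to a temporary location chosen for this sketch.
out_path = os.path.join(tempfile.mkdtemp(), "model.png")

try:
    # show_shapes=True adds input/output shapes to each layer box.
    plot_model(model, to_file=out_path, show_shapes=True)
    print("Diagram written to:", out_path)
except (ImportError, OSError) as err:
    # Raised when pydot or Graphviz is not available.
    print("plot_model could not render the diagram:", err)
```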
Lastly, you can also call model.summary(), which generates a textual version of the graph, with additional summaries.
Here is an example of the output of model.summary():
Layer (type)                     Output Shape          Param #     Connected to
====================================================================================================
input_1 (InputLayer)             (None, 2048)          0
____________________________________________________________________________________________________
activation_1 (Activation)        (None, 2048)          0
____________________________________________________________________________________________________
dense_1 (Dense)                  (None, 511)           1047039
____________________________________________________________________________________________________
activation_2 (Activation)        (None, 511)           0
____________________________________________________________________________________________________
decoder_layer_1 (DecoderLayer)   (None, 512)           0
____________________________________________________________________________________________________
ctg_output (OrLayer)             (None, 201)           102912
____________________________________________________________________________________________________
att_output (OrLayer)             (None, 312)           159744
====================================================================================================
Total params: 1,309,695.0
Trainable params: 1,309,695.0
Non-trainable params: 0.0
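Descriptive layer names such as ctg_output in the summary above can be set through the name argument that Keras layers accept. The sketch below assumes a toy model with two input features; the names "features", "classifier", and "linear_classifier" are made up for illustration. It also captures the summary as text through print_fn, accepting extra keyword arguments because some Keras versions pass them.

```python
import keras

# Hypothetical toy model with explicitly named layers.
inputs = keras.Input(shape=(2,), name="features")
outputs = keras.layers.Dense(1, activation="sigmoid", name="classifier")(inputs)
model = keras.Model(inputs, outputs, name="linear_classifier")

# Collect the summary instead of printing it directly.
# **kwargs absorbs any extra arguments a Keras version may pass to print_fn.
lines = []
model.summary(print_fn=lambda line, **kwargs: lines.append(line))
summary_text = "\n".join(lines)
print(summary_text)
```

These names only label the Keras layers in the summary and the plot_model diagram; they do not reduce the number of operations in the underlying TensorFlow graph.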