Tags: python, tensorflow, plot, keras, deep-learning

How do I plot a Keras/Tensorflow subclassing API model?


I built a model with the Keras Subclassing API, and it runs correctly; model.summary() works as well. But when I use tf.keras.utils.plot_model() to visualize the model's architecture, all it outputs is this image:

[image not preserved]

This almost feels like a joke from the Keras development team. This is the full architecture:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from sklearn.datasets import load_diabetes
import tensorflow as tf
tf.keras.backend.set_floatx('float64')
from tensorflow.keras.layers import Dense, GaussianDropout, GRU, Concatenate, Reshape
from tensorflow.keras.models import Model

X, y = load_diabetes(return_X_y=True)

data = tf.data.Dataset.from_tensor_slices((X, y)).\
    shuffle(len(X)).\
    map(lambda x, y: (tf.divide(x, tf.reduce_max(x)), y))

training = data.take(400).batch(8)
testing = data.skip(400).map(lambda x, y: (tf.expand_dims(x, 0), y))

class NeuralNetwork(Model):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.dense1 = Dense(16, input_shape=(10,), activation='relu', name='Dense1')
        self.dense2 = Dense(32, activation='relu', name='Dense2')
        self.resha1 = Reshape((1, 32))
        self.gru1 = GRU(16, activation='tanh', recurrent_dropout=1e-1)
        self.dense3 = Dense(64, activation='relu', name='Dense3')
        self.gauss1 = GaussianDropout(5e-1)
        self.conca1 = Concatenate()
        self.dense4 = Dense(128, activation='relu', name='Dense4')
        self.dense5 = Dense(1, name='Dense5')

    def call(self, x, *args, **kwargs):
        x = self.dense1(x)
        x = self.dense2(x)
        a = self.resha1(x)
        a = self.gru1(a)
        b = self.dense3(x)
        b = self.gauss1(b)
        x = self.conca1([a, b])
        x = self.dense4(x)
        x = self.dense5(x)
        return x


skynet = NeuralNetwork()
skynet.build(input_shape=(None, 10))
skynet.summary()

model = tf.keras.utils.plot_model(model=skynet, show_shapes=True,
                                  to_file='/home/nicolas/Desktop/model.png')

Solution

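  • I've found a workaround for plotting a model built with the subclassing API. Because a subclassed model's forward pass is ordinary Python code inside call(), Keras has no static layer graph to inspect the way it does for Sequential or Functional models. That is why model.summary() gives only a flat list of layers (often with "multiple" as the output shape) and plot_model() can't draw the architecture. Here I'll show how to get both a proper summary and a plot.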

    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import Input, Model
    from tensorflow.keras import layers as L

    class my_model(keras.Model):
        def __init__(self, dim):
            super(my_model, self).__init__()
            self.dim   = dim  # per-sample input shape, reused in build_graph()
            self.Base  = keras.applications.VGG16(
                input_shape=dim,
                include_top=False,
                weights='imagenet'
            )
            self.GAP   = L.GlobalAveragePooling2D()
            self.BAT   = L.BatchNormalization()
            self.DROP  = L.Dropout(rate=0.1)
            self.DENS  = L.Dense(256, activation='relu', name='dense_A')
            self.OUT   = L.Dense(1, activation='sigmoid')

        def call(self, inputs):
            x  = self.Base(inputs)
            g  = self.GAP(x)
            b  = self.BAT(g)
            d  = self.DROP(b)
            d  = self.DENS(d)
            return self.OUT(d)

        # AFAIK, the most convenient way to get a summary() (and a plot)
        # like the Sequential or Functional API: trace call() on a symbolic
        # Input and wrap the result in a Functional Model.
        def build_graph(self):
            x = Input(shape=self.dim)
            return Model(inputs=[x], outputs=self.call(x))

    dim = (124, 124, 3)
    model = my_model(dim)
    model.build((None, *dim))
    model.build_graph().summary()
    

    This produces:

    Layer (type)                 Output Shape              Param #   
    =================================================================
    input_67 (InputLayer)        [(None, 124, 124, 3)]     0         
    _________________________________________________________________
    vgg16 (Functional)           (None, 3, 3, 512)         14714688  
    _________________________________________________________________
    global_average_pooling2d_32  (None, 512)               0         
    _________________________________________________________________
    batch_normalization_7 (Batch (None, 512)               2048      
    _________________________________________________________________
    dropout_5 (Dropout)          (None, 512)               0         
    _________________________________________________________________
    dense_A (Dense)              (None, 256)               131328    
    _________________________________________________________________
    dense_7 (Dense)              (None, 1)                 257       
    =================================================================
    Total params: 14,848,321
    Trainable params: 14,847,297
    Non-trainable params: 1,024
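The Total params line can be checked by hand. A Dense layer has inputs × units weights plus one bias per unit, and BatchNormalization keeps four values per channel: gamma and beta are trainable, while the moving mean and moving variance are not, which is where the 1,024 non-trainable parameters come from. A small sketch of that arithmetic, taking the VGG16 (include_top=False) count from the summary as given:

```python
def dense_params(n_in, units):
    # kernel (n_in x units) plus one bias per unit
    return n_in * units + units


def batchnorm_params(channels):
    # returns (all params, non-trainable params):
    # gamma and beta are trainable, moving mean and moving variance are not
    return 4 * channels, 2 * channels


vgg16 = 14_714_688                 # taken from the summary above
bn_all, bn_frozen = batchnorm_params(512)
dense_a = dense_params(512, 256)   # GAP output (512) -> dense_A (256)
dense_out = dense_params(256, 1)   # dense_A (256) -> sigmoid output (1)

total = vgg16 + bn_all + dense_a + dense_out
trainable = total - bn_frozen
print(dense_a, dense_out)   # 131328 257
print(total, trainable)     # 14848321 14847297
```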
    

    Now, by passing the model returned by build_graph() to plot_model, we can plot the whole architecture.

    # Showing the most common arguments for newcomers.
    # plot_model needs the pydot package and Graphviz installed.
    tf.keras.utils.plot_model(
        model.build_graph(),                      # here is the trick (for now)
        to_file='model.png', dpi=96,              # output file and resolution
        show_shapes=True, show_layer_names=True,  # show output shapes and layer names
        expand_nested=False                       # set True to expand nested models (e.g. vgg16)
    )
    

    It produces the following plot:

    [image not preserved]
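The same trick applies to the NeuralNetwork from the question. Instead of adding a build_graph() method to each class, a minimal sketch can wrap any subclassed model with a standalone helper; as_functional is a hypothetical name, the layers are copied from the question (minus input_shape= and set_floatx, which the trick doesn't need), and TensorFlow 2.x tf.keras is assumed. The helper calls model.call(...) rather than model(...): calling the model object itself would typically add it to the graph as one nested node, so the plot would again show a single box instead of the layers.

```python
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Concatenate, Dense, GaussianDropout, GRU, Reshape


def as_functional(model, sample_shape):
    """Hypothetical helper: trace model.call() on a symbolic Input so the
    resulting Functional Model exposes every layer and connection."""
    inputs = Input(shape=sample_shape)  # sample_shape excludes the batch dimension
    return Model(inputs=inputs, outputs=model.call(inputs))


class NeuralNetwork(Model):
    # Same layers as in the question; input_shape= is dropped because
    # a subclassed model does not use it.
    def __init__(self):
        super().__init__()
        self.dense1 = Dense(16, activation='relu', name='Dense1')
        self.dense2 = Dense(32, activation='relu', name='Dense2')
        self.resha1 = Reshape((1, 32))
        self.gru1 = GRU(16, activation='tanh', recurrent_dropout=1e-1)
        self.dense3 = Dense(64, activation='relu', name='Dense3')
        self.gauss1 = GaussianDropout(5e-1)
        self.conca1 = Concatenate()
        self.dense4 = Dense(128, activation='relu', name='Dense4')
        self.dense5 = Dense(1, name='Dense5')

    def call(self, x):
        x = self.dense1(x)
        x = self.dense2(x)
        a = self.gru1(self.resha1(x))     # recurrent branch
        b = self.gauss1(self.dense3(x))   # dense branch
        x = self.conca1([a, b])           # the two branches merge here
        x = self.dense4(x)
        return self.dense5(x)


skynet = NeuralNetwork()
graph = as_functional(skynet, (10,))
graph.summary()  # lists every layer with its output shape
# Drawing the graph additionally needs pydot and Graphviz:
# tf.keras.utils.plot_model(graph, show_shapes=True, to_file='model.png')
```

The wrapper reuses the same layer objects, so its weights are shared with skynet and the plot reflects the model you actually train.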


    Similar QnA:

    1. Retrieving Keras Layer Properties from a tf.keras.Model
    2. Visualize nested keras.Model (SubClassed API) GAN model