Search code examples
Tags: tensorflow, machine-learning, keras, subclassing, distributed-training

How to use model subclassing in Keras?


I have the following model written with the Keras Sequential API:

# Hyperparameters for the LSTM model assembled in get_model().
config = {
    "learning_rate": 0.001,
    "lstm_neurons": 32,
    "lstm_activation": "tanh",
    # Dropout rate applied after the LSTM + batch-norm stage.
    "dropout_rate": 0.08,
    "batch_size": 128,
    # One entry per hidden Dense layer, in order. An entry may also carry
    # its own "dropout_rate" to add dropout after that layer.
    "dense_layers": [
        {"neurons": 32, "activation": "relu"},
        {"neurons": 32, "activation": "relu"},
    ],
}

def get_model(num_features, output_size):
    """Build and compile the Sequential LSTM model described by ``config``.

    Args:
        num_features: Size of the last (feature) axis of the ragged input.
        output_size: Number of units in the final sigmoid ``Dense`` layer.

    Returns:
        The compiled ``Sequential`` model (MSE loss, Adam optimizer).
    """
    # Read the learning rate from config rather than repeating the literal;
    # the fallback keeps the previous value if the key is absent.
    opt = Adam(learning_rate=config.get('learning_rate', 0.001))

    model = Sequential()
    model.add(Input(shape=[None, num_features], dtype=tf.float32, ragged=True))
    model.add(LSTM(config['lstm_neurons'], activation=config['lstm_activation']))
    model.add(BatchNormalization())
    if 'dropout_rate' in config:
        model.add(Dropout(config['dropout_rate']))

    # Hidden stack: Dense -> BatchNorm, plus Dropout only when the layer
    # entry itself defines a 'dropout_rate'.
    for layer in config['dense_layers']:
        model.add(Dense(layer['neurons'], activation=layer['activation']))
        model.add(BatchNormalization())
        if 'dropout_rate' in layer:
            model.add(Dropout(layer['dropout_rate']))

    model.add(Dense(output_size, activation='sigmoid'))
    model.compile(loss='mse', optimizer=opt, metrics=['mse'])

    # summary() prints the table itself and returns None, so wrapping it in
    # print() would emit an extra "None" line.
    model.summary()
    return model

When using a distributed training framework, I need to convert the syntax to use model subclassing instead. I've looked at the docs but couldn't figure out how to do it.


Solution

  • Here is one equivalent subclassed implementation, though I haven't tested it.

    import tensorflow as tf 
    
    # Hyperparameters read by the subclassed model below.
    config = dict(
        learning_rate=0.001,
        lstm_neurons=32,
        lstm_activation='tanh',
        dropout_rate=0.08,
        batch_size=128,
        # One entry per hidden Dense layer, in order.
        dense_layers=[
            dict(neurons=32, activation='relu'),
            dict(neurons=32, activation='relu'),
        ],
    )
    
    # Subclassed API Model 
    class MySubClassed(tf.keras.Model):
        """LSTM model written with the Keras model-subclassing API.

        Layer sizes, activations and dropout rates are read from the
        module-level ``config`` dict. The pipeline is::

            LSTM -> BatchNorm -> [Dropout]
                 -> (Dense -> BatchNorm -> [Dropout]) per ``dense_layers`` entry
                 -> Dense(output_size, sigmoid)

        Dropout after a hidden Dense layer uses that entry's own
        ``'dropout_rate'`` when present and otherwise falls back to the
        global ``config['dropout_rate']``; it is omitted when neither exists.
        """

        def __init__(self, output_size):
            """Create all layers up front so Keras can track their weights.

            Args:
                output_size: Number of units in the final sigmoid layer.
            """
            super().__init__()
            layers = tf.keras.layers
            default_rate = config.get('dropout_rate')

            self.lstm = layers.LSTM(config['lstm_neurons'],
                                    activation=config['lstm_activation'])
            self.bn = layers.BatchNormalization()
            self.dp1 = (layers.Dropout(default_rate)
                        if default_rate is not None else None)

            # Build the hidden stack from config. Each entry gets its own
            # layers; reusing fixed attribute names inside the loop would
            # overwrite earlier layers and ignore the number of entries.
            hidden = []
            for spec in config['dense_layers']:
                hidden.append(layers.Dense(spec['neurons'],
                                           activation=spec['activation']))
                hidden.append(layers.BatchNormalization())
                rate = spec.get('dropout_rate', default_rate)
                if rate is not None:
                    hidden.append(layers.Dropout(rate))
            # Assign once, after the list is complete, so the model tracks
            # every layer in it.
            self.hidden_layers = hidden

            self.out = layers.Dense(output_size, activation='sigmoid')

        def call(self, inputs, training=True, **kwargs):
            """Run the forward pass.

            Args:
                inputs: Batch of input sequences fed to the LSTM.
                training: Training/inference switch for this call.
                **kwargs: Accepted for signature compatibility; unused.

            Returns:
                The output of the final sigmoid ``Dense`` layer.
            """
            # NOTE(review): `training` is not forwarded explicitly; this
            # relies on Keras propagating the flag to nested BatchNorm /
            # Dropout layers through its call context -- confirm for the
            # Keras version in use.
            x = self.lstm(inputs)
            x = self.bn(x)
            if self.dp1 is not None:
                x = self.dp1(x)

            for layer in self.hidden_layers:
                x = layer(x)

            return self.out(x)

        # A convenient way to get model summary
        # and plot in subclassed api
        def build_graph(self, raw_shape):
            """Wrap the layers in a functional ``Model`` for summary/plotting.

            Args:
                raw_shape: Size of the feature axis of the ragged input.

            Returns:
                A ``tf.keras.Model`` mapping a symbolic ragged input of shape
                ``(None, raw_shape)`` to the output of ``call``.
            """
            x = tf.keras.layers.Input(shape=(None, raw_shape),
                                      ragged=True)
            return tf.keras.Model(inputs=[x],
                                  outputs=self.call(x))
    

    Build and compile the model.

     # Instantiate the subclassed model with a single output unit.
     s = MySubClassed(output_size=1)
     # Take the learning rate from config instead of repeating the literal;
     # the fallback keeps the previous value if the key is absent.
     s.compile(
         loss='mse',
         metrics=['mse'],
         optimizer=tf.keras.optimizers.Adam(
             learning_rate=config.get('learning_rate', 0.001)))
    

    Pass a dummy tensor through the model to create its weights (sanity check).

    # Push one all-ones batch through the model so its weights get created.
    # Presumably (batch, timesteps, features) = (16, 16, 16) -- TODO confirm.
    raw_input = (16, 16, 16)
    y = s(tf.ones(shape=raw_input))

    print(f"weights: {len(s.weights)}")
    print(f"trainable weights: {len(s.trainable_weights)}")
    
    weights: 21
    trainable weights: 15
    

    Summary and Plot

    Summarize and visualize the model graph.

    # Wrap the subclassed model in a functional graph (16 input features)
    # so that summary() can list every layer with its output shape.
    graph_model = s.build_graph(16)
    graph_model.summary()
    
    Model: "model"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    input_1 (InputLayer)         [(None, None, 16)]        0         
    _________________________________________________________________
    lstm (LSTM)                  (None, 32)                6272      
    _________________________________________________________________
    batch_normalization (BatchNo (None, 32)                128       
    _________________________________________________________________
    dropout (Dropout)            (None, 32)                0         
    _________________________________________________________________
    dense_2 (Dense)              (None, 32)                1056      
    _________________________________________________________________
    batch_normalization_3 (Batch (None, 32)                128       
    _________________________________________________________________
    dropout_1 (Dropout)          (None, 32)                0         
    _________________________________________________________________
    dense_3 (Dense)              (None, 32)                1056      
    _________________________________________________________________
    batch_normalization_4 (Batch (None, 32)                128       
    _________________________________________________________________
    dropout_2 (Dropout)          (None, 32)                0         
    _________________________________________________________________
    dense_4 (Dense)              (None, 1)                 33        
    =================================================================
    Total params: 8,801
    Trainable params: 8,609
    Non-trainable params: 192
    
    # Draw the layer graph top-to-bottom, annotated with shapes, dtypes
    # and layer names.
    plot_options = dict(
        show_shapes=True,
        show_dtype=True,
        show_layer_names=True,
        rankdir="TB",
    )
    tf.keras.utils.plot_model(s.build_graph(16), **plot_options)
    

    enter image description here