Tags: tensorflow, keras, keras-layer, tf.keras

Freeze sublayers in TensorFlow 2


I have a model that is composed of custom layers, and each custom layer contains several tf.keras layers. The problem is that if I want to freeze those inner layers after defining my model, the loop:

for i, layer in enumerate(model.layers):
    print(i, layer.name)

only prints the "outer" custom layers, not those nested inside them. Is there any way to access the inner layers so I can freeze them?
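
For illustration, here is the kind of traversal I'm after; this sketch leans on the `submodules` attribute that Keras layers inherit from `tf.Module`, though I'm not sure it's the intended approach:

for i, sub in enumerate(model.submodules):
    # `submodules` walks nested modules depth-first, so this also
    # reaches the tf.keras layers hidden inside each custom layer
    if isinstance(sub, tf.keras.layers.Layer):
        print(i, sub.name)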

An example of a custom layer from the official TF docs:

class MLPBlock(layers.Layer):

  def __init__(self):
    super(MLPBlock, self).__init__()
    self.linear_1 = Linear(32)
    self.linear_2 = Linear(32)
    self.linear_3 = Linear(1)

  def call(self, inputs):
    x = self.linear_1(inputs)
    x = tf.nn.relu(x)
    x = self.linear_2(x)
    x = tf.nn.relu(x)
    return self.linear_3(x)
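
(For completeness, `Linear` is the custom layer defined earlier in the same guide; roughly, it looks like this:)

class Linear(layers.Layer):

  def __init__(self, units=32):
    super(Linear, self).__init__()
    self.units = units

  def build(self, input_shape):
    # create the kernel and bias lazily, once the input shape is known
    self.w = self.add_weight(shape=(input_shape[-1], self.units),
                             initializer="random_normal", trainable=True)
    self.b = self.add_weight(shape=(self.units,),
                             initializer="zeros", trainable=True)

  def call(self, inputs):
    return tf.matmul(inputs, self.w) + self.b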

Solution

  • OK, I came up with a solution. An "update" method can be implemented inside the custom layer, which flips its inner layers to non-trainable. Here is some sample code:

    import tensorflow as tf
    import numpy as np
    
    layers = tf.keras.layers
    
    seq_model = tf.keras.models.Sequential
    
    
    class MDBlock(layers.Layer):
    
        def __init__(self):
            super(MDBlock, self).__init__()
            self.dense1 = layers.Dense(784, name="first")
            self.dense2 = layers.Dense(32, name="second")
            self.dense3 = layers.Dense(32, name="third")
            self.dense4 = layers.Dense(1, activation='sigmoid', name="outp")
    
        def call(self, inputs):
            x = self.dense1(inputs)
            x = tf.nn.relu(x)
            x = self.dense2(x)
            x = tf.nn.relu(x)
            x = self.dense3(x)
            x = tf.nn.relu(x)
            x = self.dense4(x)
            return x
    
        def updt(self):
            # freeze only the first inner layer, as a demonstration;
            # any sublayer can be flipped the same way
            self.dense1.trainable = False
    
        def __str__(self):
            return "\nd1:{0}\nd2:{1}\nd3:{2}\nd4:{3}".format(self.dense1.trainable, self.dense2.trainable,
                                                             self.dense3.trainable, self.dense4.trainable)
    
    
    # define the layer block
    block = MDBlock()

    model = seq_model()
    model.add(layers.Input(shape=(784,)))
    model.add(block)

    # freeze the inner layers *before* compiling: changes to
    # `trainable` only take effect once the model is (re)compiled
    for layer in model.layers:
        layer.updt()
    
    model.compile(optimizer='rmsprop',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    
    # Generate dummy data
    data = np.random.random((1000, 784))
    labels = np.random.randint(2, size=(1000, 1))
    
    # Train the model, iterating on the data in batches of 32 samples
    model.fit(data, labels, epochs=10, batch_size=32)
    
    # print the trainable state of the block's inner layers
    for i, layer in enumerate(model.layers):
        print(i, layer)
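
    Worth noting: `trainable` is recursive in tf.keras, so setting it to False on the outer custom layer freezes all of its inner layers' weights without any custom method. A minimal sketch, reusing the `MDBlock` and `seq_model` aliases from above:

    model = seq_model()
    model.add(layers.Input(shape=(784,)))
    model.add(MDBlock())

    # setting `trainable = False` on a layer also freezes all of its
    # sublayers' weights, so no per-layer update method is needed
    model.layers[0].trainable = False

    model.compile(optimizer='rmsprop',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    print(len(model.trainable_weights))      # 0 -- everything is frozen
    print(len(model.non_trainable_weights))  # 8 -- 4 kernels + 4 biases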