machine-learning, nlp, bert-language-model, transfer-learning

Is it possible to freeze only some of the parameters in a single layer of TFBertModel?


I am trying to use the pretrained BERT model in TensorFlow, which has approximately 110 million parameters, and it is nearly impossible to train all of them on my GPU. On the other hand, freezing the entire layer makes all of these parameters untrainable.

Is it possible to make the layer partially trainable, for example with a couple of million parameters trainable and the rest frozen?

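import tensorflow as tf
from transformers import TFAutoModel

max_len = 128  # sequence length; not given in the question, example value only

input_ids_layer = tf.keras.layers.Input(shape=(max_len,), dtype=tf.int32, name='input_ids')
input_attention_layer = tf.keras.layers.Input(shape=(max_len,), dtype=tf.int32, name='attention_mask')

model = TFAutoModel.from_pretrained("bert-base-uncased")

# TFBertModel wraps the whole network in a single Keras layer (TFBertMainLayer)
for layer in model.layers:
    # assuming there are 199 weights: only those after index 150 stay trainable
    for i in range(len(layer.weights)):
        # _trainable is a private tf.Variable attribute; Keras may not honor changes to it
        if i > 150:
            layer.weights[i]._trainable = True
        else:
            layer.weights[i]._trainable = False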
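If the goal really is to update only some of the variables inside one layer, a common alternative to flipping the private `_trainable` flag is a custom training loop that computes and applies gradients for a chosen subset of variables only. The sketch below is a minimal illustration under stated assumptions: it uses a small stand-in `Dense` layer and random data (both hypothetical, not BERT) and trains only the bias while the kernel stays fixed. With BERT, the subset might be a slice of `model.trainable_variables`, but which variables to pick is an assumption you would have to adapt.

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)

# Hypothetical stand-in for one large pretrained layer:
# a Dense layer with a kernel (8x4) and a bias (4,).
layer = tf.keras.layers.Dense(4)
layer.build((None, 8))

# Only the variables in this list will be updated; the kernel is left out.
selected = [layer.bias]

# Random toy data standing in for a real dataset.
x = tf.random.normal((16, 8))
y = tf.random.normal((16, 4))
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

kernel_before = layer.kernel.numpy().copy()
bias_before = layer.bias.numpy().copy()

for step in range(5):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(layer(x) - y))
    # Gradients are requested only for the selected variables...
    grads = tape.gradient(loss, selected)
    # ...and only those variables are passed to the optimizer.
    optimizer.apply_gradients(zip(grads, selected))

kernel_after = layer.kernel.numpy()
bias_after = layer.bias.numpy()
```

Activations are still recorded for the backward pass, so this mainly controls which weights change and keeps optimizer state small; the memory savings may be smaller than with freezing whole layers.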

Solution

  • I don't know about training only some of the weights inside a layer, but I still suggest the standard approach: freezing layers is what is usually done in these cases to avoid retraining everything. However, you should not freeze all the layers, since then nothing would be trained. What you want to do is freeze everything except the last few layers, and then train the network.

    This works because the first layers usually learn general, low-level features that transfer across many problems, whereas the last layers usually learn the more task-specific features that actually solve the problem at hand for the current dataset.

    Therefore, if you want to re-train a pretrained model on another dataset, you usually only need to retrain the last few layers. You can also modify the end of the network by adding some Dense layers and changing the output layer, which is useful if, for example, the number of classes to predict differs from the original dataset. There are many short tutorials online that show how to do this.

    To summarize:

    1. Freeze all the layers except the last one
    2. (optional) Create new layers and connect them to the output of the second-last layer
    3. Train the network
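The three steps can be sketched in Keras with a small functional model standing in for the pretrained network. Everything here is hypothetical (layer sizes, names such as `hidden_3`, the 3-class head, the random data); it is meant only to show the mechanics: freeze every layer except the last hidden one, attach a new output layer to that layer's output, and train, so that only the unfrozen weights change.

```python
import numpy as np
import tensorflow as tf

tf.keras.utils.set_random_seed(0)

# Hypothetical stand-in for a pretrained network: three hidden layers
# plus an original 10-class output layer.
inputs = tf.keras.Input(shape=(16,))
h = tf.keras.layers.Dense(32, activation="relu", name="hidden_1")(inputs)
h = tf.keras.layers.Dense(32, activation="relu", name="hidden_2")(h)
h = tf.keras.layers.Dense(32, activation="relu", name="hidden_3")(h)
old_out = tf.keras.layers.Dense(10, name="old_output")(h)
base = tf.keras.Model(inputs, old_out)

# Step 1: freeze every layer except the last hidden one.
for layer in base.layers:
    layer.trainable = layer.name == "hidden_3"

# Step 2: drop the old output and attach a new 3-class head to hidden_3.
features = base.get_layer("hidden_3").output
new_out = tf.keras.layers.Dense(3, name="new_output")(features)
model = tf.keras.Model(base.input, new_out)

# Step 3: compile after changing `trainable`, then train on toy data.
hidden_1_before = model.get_layer("hidden_1").get_weights()[0].copy()
hidden_3_before = model.get_layer("hidden_3").get_weights()[0].copy()

model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
rng = np.random.default_rng(0)
x = rng.random((64, 16), dtype=np.float32)
y = rng.integers(0, 3, size=64)
model.fit(x, y, epochs=1, verbose=0)

# Trainable parameters: hidden_3 (32*32 + 32) + new_output (32*3 + 3).
trainable_params = sum(int(np.prod(tuple(w.shape))) for w in model.trainable_weights)
print("trainable parameters:", trainable_params)
```

With a Hugging Face `TFBertModel`, `model.layers` holds a single `TFBertMainLayer`, so freezing "layers" has to happen one level down. Assuming the attribute names used in the transformers source (`model.bert.embeddings` and the list of transformer blocks `model.bert.encoder.layer`, which may differ between versions), step 1 would amount to setting `trainable = False` on the embeddings and on all but the last one or two blocks before compiling.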