
Method to transfer weights between nested keras models


I'm trying to successively build up mixture models, iteratively adding sub-models.

I start by building and training a simple model. I then build a slightly more complex model that contains all of the original model plus additional layers, so the first model is nested inside the second. I want to move the trained weights from the first model into the new one. How can I do this?

Here's a dummy MWE:

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model, Input, backend

# data
x = np.random.normal(size = 100)
y = np.sin(x)+np.random.normal(size = 100)
# model 1
def make_model_1():
    inp = Input(1)
    l1 = Dense(5, activation = 'relu')(inp)
    out1 = Dense(1)(l1)
    model1 = Model(inp, out1)
    return model1

model1 = make_model_1()

model1.compile(optimizer = tf.keras.optimizers.SGD(),
               loss = tf.keras.losses.mean_squared_error)

model1.fit(x, y, epochs = 3, batch_size = 10)

# make model 2
def make_model_2():
    inp = Input(1)
    l1 = Dense(5, activation = 'relu')(inp)
    out1 = Dense(1)(l1)
    l2 = Dense(15, activation = 'sigmoid')(inp)
    out2 = Dense(1)(l2)
    bucket = tf.stack([out1, out2], axis=2)
    out = backend.squeeze(Dense(1)(bucket), axis = 2)
    model2 = Model(inp, out)
    return model2

model2 = make_model_2()

How can I transfer the weights from model1 to model2, in a way that's automatic and completely agnostic about the nature of the two models, except that they are nested?


Solution

  • You can simply load the trained weights into the specific part of the new model you are interested in. I do this by creating a new instance of model1 inside model2 and then loading the trained weights into it.

    Here is the full example:

    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.layers import Dense
    from tensorflow.keras import Model, Input
    
    # data
    x = np.random.normal(size = 100)
    y = np.sin(x)+np.random.normal(size = 100)
    
    # model 1
    def make_model_1():
        
        inp = Input(1)
        l1 = Dense(5, activation = 'relu')(inp)
        out1 = Dense(1)(l1)
        model1 = Model(inp, out1)
        
        return model1
    
    model1 = make_model_1()
    model1.compile(optimizer = tf.keras.optimizers.SGD(),
                   loss = tf.keras.losses.mean_squared_error)
    model1.fit(x, y, epochs = 3, batch_size = 10)
    
    # make model 2
    def make_model_2(trained_model):
        
        inp = Input(1)
    
        m = make_model_1()
        m.set_weights(trained_model.get_weights())
        out1 = m(inp)
        
        l2 = Dense(15, activation = 'sigmoid')(inp)
        out2 = Dense(1)(l2)
        bucket = tf.stack([out1, out2], axis=2)
        out = tf.keras.backend.squeeze(Dense(1)(bucket), axis = 2)
        model2 = Model(inp, out)
        
        return model2
    
    model2 = make_model_2(model1)
    model2.summary()
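
    To sanity-check the transfer (this part is not in the original answer, just a sketch continuing from the example above), you can locate the nested Model inside model2, compare its weights against model1, and optionally freeze it so that further training of model2 only updates the newly added branch:

    # hypothetical follow-up, not part of the original answer
    # locate the nested Model instance among model2's layers
    nested = next(l for l in model2.layers if isinstance(l, Model))
    
    # its weights should now equal model1's trained weights
    for w_new, w_old in zip(nested.get_weights(), model1.get_weights()):
        assert np.allclose(w_new, w_old)
    
    # optionally freeze the transferred sub-model before compiling,
    # so only the new layers are updated when model2 is trained
    nested.trainable = False
    
    model2.compile(optimizer = tf.keras.optimizers.SGD(),
                   loss = tf.keras.losses.mean_squared_error)
    model2.fit(x, y, epochs = 3, batch_size = 10)

    Freezing is optional: leave nested.trainable as True if you want the transferred weights to keep adapting jointly with the new branch. Note that changes to trainable only take effect once the model is (re)compiled.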