Tags: python, tensorflow, keras, tensorflow2

How to revert a Keras Sequential model to its state before build()?


I want to reuse the same model architecture with different datasets. That is, I want to programmatically change the input layer to a different shape and, if needed, reset the model parameters.

Something along the lines of:

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(2)
])

optimizer = tf.keras.optimizers.Adam()

losses = [tf.keras.losses.mean_absolute_percentage_error]

model.compile(optimizer=optimizer, loss=losses)

model.build(input_shape=(None, 2))
# ... train model and evaluate

model.unbuild()  # this method doesn't exist; it's what I'm looking for
model.build(input_shape=(None, 3))
# ... train model and evaluate on a different dataset

Does anyone know a clean way to do this?
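Keras models have no `unbuild()` method, but the pseudocode above can be approximated by constructing a brand-new model for each dataset. The following is a minimal sketch, assuming a hypothetical helper `make_model(input_dim)` and random placeholder data. Each call creates new layers, which means freshly initialized weights, along with a new optimizer:

```python
import numpy as np
import tensorflow as tf

def make_model(input_dim):
    """Hypothetical helper: builds and compiles a fresh model for one input width.

    Every call creates new layer objects, so the weights are newly initialized.
    This stands in for the nonexistent model.unbuild() + model.build().
    """
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(input_dim,)),  # fixes the input shape at construction
        tf.keras.layers.Dense(2),
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),  # new optimizer state per model
        loss="mean_absolute_percentage_error",
    )
    return model

# Random placeholder data, for illustration only
model_a = make_model(2)
model_a.fit(np.random.rand(8, 2), np.random.rand(8, 2), epochs=1, verbose=0)

model_b = make_model(3)
model_b.fit(np.random.rand(8, 3), np.random.rand(8, 2), epochs=1, verbose=0)
```

Because `tf.keras.Input(shape=...)` is the first element of the list, the input shape is known when the model is created, so no separate `model.build()` call is needed.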


Solution

  • You can create one backbone model and reuse it to build as many models as you want, each with a different input layer. All models built this way share the backbone model's parameters (weights). If you want to reset the parameters, build a new backbone model. Example code:

    import tensorflow as tf
    from tensorflow.keras import layers, models
    import numpy as np
    
    input_shape_b = (16, )
    # Backbone model
    def build_backbone_model():
        inputs_b = layers.Input(shape=input_shape_b)
        h = layers.Dense(256, 'relu')(inputs_b)
        outputs_b = layers.Dense(1, 'sigmoid')(h)
        return models.Model(inputs_b, outputs_b, name="Backbone")
        
    backbone_model = build_backbone_model()
    backbone_model.summary()
    
    def new_model_reuse_backbone(input_shape, name):
        inputs = layers.Input(shape=input_shape)
        h = layers.Dense(input_shape_b[0], 'relu')(inputs)
        outputs = backbone_model(h)
        return models.Model(inputs, outputs, name=name)
    
    # Reuses the backbone model defined above (its weights are shared)
    new_model_0 = new_model_reuse_backbone((32, ), "new_model_0")
    new_model_0.summary()
    
    # Building a new backbone gives freshly initialized (reset) parameters;
    # new_model_0 still refers to the old backbone
    backbone_model = build_backbone_model()
    new_model_1 = new_model_reuse_backbone((256, ), "new_model_1")
    new_model_1.summary()
    

    Outputs:

    Model: "Backbone"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #
    =================================================================
    input_1 (InputLayer)         [(None, 16)]              0
    _________________________________________________________________
    dense (Dense)                (None, 256)               4352
    _________________________________________________________________
    dense_1 (Dense)              (None, 1)                 257
    =================================================================
    Total params: 4,609
    Trainable params: 4,609
    Non-trainable params: 0
    _________________________________________________________________
    Model: "new_model_0"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #
    =================================================================
    input_2 (InputLayer)         [(None, 32)]              0
    _________________________________________________________________
    dense_2 (Dense)              (None, 16)                528
    _________________________________________________________________
    Backbone (Functional)        (None, 1)                 4609
    =================================================================
    Total params: 5,137
    Trainable params: 5,137
    Non-trainable params: 0
    _________________________________________________________________
    Model: "new_model_1"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    input_3 (InputLayer)         [(None, 256)]             0
    _________________________________________________________________
    dense_3 (Dense)              (None, 16)                4112
    _________________________________________________________________
    Backbone (Functional)        (None, 1)                 4609
    =================================================================
    Total params: 8,721
    Trainable params: 8,721
    Non-trainable params: 0
    _________________________________________________________________
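To check that the models really share the backbone's weights, and that rebuilding the backbone resets them, you can compare weight arrays directly. The following is a minimal sketch that reproduces the answer's structure. One assumption differs from the answer: the helper here (hypothetical name `new_model_reuse_backbone` with an extra `backbone` argument) receives the backbone explicitly instead of reading a global variable. The training data is random and only illustrative:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models

input_shape_b = (16,)

def build_backbone_model():
    inputs_b = layers.Input(shape=input_shape_b)
    h = layers.Dense(256, activation="relu")(inputs_b)
    outputs_b = layers.Dense(1, activation="sigmoid")(h)
    return models.Model(inputs_b, outputs_b, name="Backbone")

def new_model_reuse_backbone(backbone, input_shape, name):
    # Variant of the answer's helper: the backbone is passed in explicitly
    inputs = layers.Input(shape=input_shape)
    h = layers.Dense(input_shape_b[0], activation="relu")(inputs)
    outputs = backbone(h)
    return models.Model(inputs, outputs, name=name)

backbone = build_backbone_model()
model_32 = new_model_reuse_backbone(backbone, (32,), "model_32")
model_64 = new_model_reuse_backbone(backbone, (64,), "model_64")

# Snapshot the backbone weights, then train only model_32 for one step
before = [w.copy() for w in backbone.get_weights()]
model_32.compile(optimizer="adam", loss="binary_crossentropy")
x = np.random.rand(8, 32).astype("float32")
y = np.random.randint(0, 2, size=(8, 1)).astype("float32")
model_32.fit(x, y, epochs=1, verbose=0)

# Training model_32 updated the shared backbone, which model_64 also contains
shared_changed = any(
    not np.allclose(a, b) for a, b in zip(before, backbone.get_weights())
)

# A newly built backbone starts from freshly initialized weights
fresh = build_backbone_model()
reset = not np.allclose(fresh.get_weights()[0], backbone.get_weights()[0])
```

Passing the backbone as an argument makes it explicit which backbone each new model wraps. In the global-variable version, `new_model_0` keeps the old backbone even after `backbone_model` is reassigned.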