
How to replace a model layer using TensorFlow 2.16?


The following works with TensorFlow 2.15:

from tensorflow.keras.layers import Input, Dense, BatchNormalization
from tensorflow.keras.models import Model

inputs = Input(shape=(4,))
x = Dense(5, activation='relu')(inputs)
predictions = Dense(3, activation='softmax')(x)
model = Model(inputs=inputs, outputs=predictions)
model.compile(loss='categorical_crossentropy', optimizer='nadam')

print(model.layers)
model._self_tracked_trackables[1] = BatchNormalization()
print(model.layers)

Output:

[<keras.src.engine.input_layer.InputLayer object at 0x7f9324ba10c0>, <keras.src.layers.core.dense.Dense object at 0x7f931c1144c0>, <keras.src.layers.core.dense.Dense object at 0x7f931c116c80>]
[<keras.src.engine.input_layer.InputLayer object at 0x7f9324ba10c0>, <keras.src.layers.normalization.batch_normalization.BatchNormalization object at 0x7f9324ba3280>, <keras.src.layers.core.dense.Dense object at 0x7f931c116c80>]

How can this be achieved with TensorFlow 2.16?

The model no longer has _self_tracked_trackables:

AttributeError: 'Functional' object has no attribute '_self_tracked_trackables'

And trying to swap out a layer like so:

model.layers[1] = BatchNormalization()

or like this:

model._layers[1] = BatchNormalization()

or like this:

model.operations[1] = BatchNormalization()

does not change the content of model.layers.

print(model.layers)

outputs the following both before and after the assignment:

[<InputLayer name=input_layer, built=True>, <Dense name=dense, built=True>, <Dense name=dense_1, built=True>]
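One plausible explanation for why these assignments seem to be ignored is this. Assume that in Keras 3 (the default Keras in TensorFlow 2.16), layers and operations are properties that build a new list from an internal _operations list on every access. In that case, writing into the returned list only changes a temporary copy. The stub class below is a hypothetical stand-in, not Keras code. It only reproduces that assumed behavior in plain Python.

```python
class StubFunctional:
    """Hypothetical stand-in for a Keras 3 functional model (not real Keras code)."""

    def __init__(self, operations):
        # Internal storage, analogous to the private `_operations` attribute.
        self._operations = list(operations)

    @property
    def operations(self):
        # Returns a shallow copy, so item assignment on the result is lost.
        return self._operations[:]

    @property
    def layers(self):
        # Rebuilt from `_operations` on every access
        # (simplified: every entry counts as a layer here).
        return list(self._operations)


model = StubFunctional(["input_layer", "dense", "dense_1"])

model.layers[1] = "batch_normalization"      # edits a temporary list
model.operations[1] = "batch_normalization"  # edits a temporary copy
print(model.layers)  # ['input_layer', 'dense', 'dense_1']

model._operations[1] = "batch_normalization"  # edits the underlying list
print(model.layers)  # ['input_layer', 'batch_normalization', 'dense_1']
```

Under this assumption, only a write to the private backing list is visible through model.layers. That matches the answer below.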

(Background: I'm trying to update my frugally-deep library to the new TF version, and it relies on swapping layers to convert nested sequential models to functional models.)


Solution

Assigning to the private _operations list works with TensorFlow 2.16:

model._operations[1] = BatchNormalization()

Please keep in mind that this code accesses a private member of the class. It is an undocumented implementation detail, which might change again in future TensorFlow versions.
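A converter like frugally-deep may need the same code to run on both TensorFlow versions. One option is a small helper that looks for whichever private list exists: _operations (TensorFlow 2.16, per the answer) or _self_tracked_trackables (TensorFlow 2.15, per the question). The name replace_layer is hypothetical. Both attributes are undocumented internals, so treat this as a sketch that may break with any future release.

```python
def replace_layer(model, index, new_layer):
    """Swap the entry at `index` in the private list that `model.layers` is built from.

    Hypothetical helper relying on undocumented attributes:
    - `_operations` (Keras 3, TensorFlow 2.16)
    - `_self_tracked_trackables` (Keras 2, TensorFlow 2.15)
    Returns the layer that was replaced.
    """
    for attr in ("_operations", "_self_tracked_trackables"):
        backing = getattr(model, attr, None)
        if backing is not None:
            old_layer = backing[index]
            backing[index] = new_layer  # in-place edit of the private list
            return old_layer
    raise AttributeError(
        f"{type(model).__name__} has neither '_operations' nor "
        "'_self_tracked_trackables'; the private layout may have changed"
    )
```

With the model from the question, replace_layer(model, 1, BatchNormalization()) followed by print(model.layers) should list the BatchNormalization layer at index 1 on either version. As far as I can tell from the Keras sources, this only edits the list that model.layers reports. The forward pass is driven by the model's graph nodes, so calling the model or running predict presumably still uses the original Dense layer. That is worth verifying if you rely on it.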