Search code examples

How to freeze Auto Encoder Layers in tensorflow

This is an autoencoder network: the first part is the encoder and the second part is the decoder. I want to freeze the first three convolution layers and save the encoder part. Can you help me with how to do this? Thank you.

def encoder(input_img):
    """Build the encoder half of the autoencoder on top of ``input_img``.

    The stack is five 2x2, 'same'-padded ReLU convolutions with 64, 32, 16,
    8 and 4 filters. Every convolution except the last is followed by batch
    normalisation, and a single 2x2 max-pool sits after the second
    conv/batch-norm pair.

    ``input_img`` is whatever the caller feeds to the first layer
    (presumably a Keras ``Input`` tensor; it is defined outside this
    snippet). Returns the output of the final 4-filter convolution.
    """
    conv_kwargs = dict(activation='relu', padding='same')

    # Layers are instantiated in the same order in which they are applied.
    stack = [
        Conv2D(64, (2, 2), **conv_kwargs),   # 'same' padding: H x W kept, 64 channels
        BatchNormalization(),
        Conv2D(32, (2, 2), **conv_kwargs),   # 32 channels
        BatchNormalization(),
        MaxPooling2D(pool_size=(2, 2)),      # halves height and width
        Conv2D(16, (2, 2), **conv_kwargs),   # 16 channels
        BatchNormalization(),
        Conv2D(8, (2, 2), **conv_kwargs),    # 8 channels
        BatchNormalization(),
        Conv2D(4, (2, 2), **conv_kwargs),    # bottleneck: 4 channels
    ]

    features = input_img
    for layer in stack:
        features = layer(features)
    return features

def decoder(conv11):
    """Build the decoder half of the autoencoder on top of ``conv11``.

    The stack is five 2x2, 'same'-padded ReLU convolutions with 4, 8, 16, 32
    and 64 filters. Every convolution except the first is followed by batch
    normalisation. A 2x2 up-sampling step then doubles height and width, and
    a final 3-filter sigmoid convolution produces the reconstruction.

    ``conv11`` is the tensor to decode (the encoder output in the
    surrounding script). Returns the output of the final sigmoid
    convolution.
    """
    relu_same = dict(activation='relu', padding='same')

    # Layers are instantiated in the same order in which they are applied.
    stack = [
        Conv2D(4, (2, 2), **relu_same),      # 4 channels, H x W kept
        Conv2D(8, (2, 2), **relu_same),      # 8 channels
        BatchNormalization(),
        Conv2D(16, (2, 2), **relu_same),     # 16 channels
        BatchNormalization(),
        Conv2D(32, (2, 2), **relu_same),     # 32 channels
        BatchNormalization(),
        Conv2D(64, (2, 2), **relu_same),     # 64 channels
        BatchNormalization(),
        UpSampling2D((2, 2)),                # doubles height and width
        Conv2D(3, (2, 2), activation='sigmoid', padding='same'),  # 3-channel output in [0, 1]
    ]

    out = conv11
    for layer in stack:
        out = layer(out)
    return out

# Chain encoder and decoder into a single end-to-end model.
# `input_img` is defined outside this snippet (presumably a Keras Input).
autoencoder = Model(input_img, decoder(encoder(input_img)))
autoencoder.compile(loss='mae', optimizer='SGD')

# The first 1900 samples of each array form the training set; the remainder
# of each array forms the validation set.
train = np.concatenate((normal[0:1900, :, :, :], un_informative[0:1900, :, :, :]), axis=0)
valid = np.concatenate((normal[1900:, :, :, :], un_informative[1900:, :, :, :]), axis=0)

# An autoencoder learns to reproduce its input, so the inputs double as the
# targets -- the same pairing already used in validation_data=(valid, valid).
# The original line had lost the `autoencoder.fit(train` call head.
history = autoencoder.fit(train, train, batch_size=batch_size, epochs=200, verbose=1,
                          validation_data=(valid, valid))


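  • You can choose which layers to freeze by setting the attribute below on each of them. Do this before calling compile(); if the model is already compiled, call compile() again afterwards, because Keras only picks up changes to trainable at compile time: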

    layer_x.trainable = False

    So I would suggest doing the following:

    layers_to_freeze = [name of the layers]
    for layer in model.layers:
        if layer in layers_to_freeze: