Tags: keras, convolution, masking

Keras freeze specific weights with mask


I am new to Keras. I want to implement a layer in which not all of the weights update. For example, in the following code, I want the dilation layer to update in such a way that some center weights are never updated. Say the shape of each feature map (out of 1024) in the dilation layer is (448, 448), and an 8x8 block at the center of every feature map should never be updated, i.e. the 8x8 block acts as a (non-trainable) mask on the feature maps.

from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D
from tensorflow.keras.models import Model

input_layer = Input(shape=(896,896,3))
new_layer = Conv2D(32, kernel_size=(3,3), padding="same", activation='relu', kernel_initializer='he_normal')(input_layer)
new_layer = MaxPooling2D(pool_size=(2, 2), strides=(2,2), padding='same', data_format=None)(new_layer)
new_layer = Conv2D(64, kernel_size=(3,3), padding='same', activation='relu', kernel_initializer='he_normal')(new_layer)
new_layer = Conv2D(1024, kernel_size=(7,7), dilation_rate=8, padding="same", activation='relu', kernel_initializer='he_normal', name='dilation')(new_layer)
new_layer = Conv2D(32, kernel_size=(1,1), padding="same", activation='relu', kernel_initializer='he_normal')(new_layer)
new_layer = Conv2D(32, kernel_size=(1,1), padding="same", activation='relu', kernel_initializer='he_normal')(new_layer)

model = Model(input_layer, new_layer)
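
For reference, a model.summary() call confirms the shapes involved: the MaxPooling2D halves 896x896 to 448x448, and every later layer uses padding="same", so the layer named 'dilation' reports an output shape of (None, 448, 448, 1024).

model.summary()  # the 'dilation' layer should show output shape (None, 448, 448, 1024)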

I was trying to follow Keras's custom layer guide [link], but it was difficult for me to understand. Could anyone please help?

UPDATE: I added the following figure for better understanding. The dilation layer contains 1024 feature maps. I want the middle region of each feature map to be non-trainable (static).

[image: the dilation layer's feature maps with the static center region]


Solution

  • Use this mask for both cases:

    import numpy as np

    mask = np.zeros((1,448,448,1))
    mask[:,220:228,220:228] = 1
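
    The hard-coded indices pin an 8x8 block at the exact center of a 448x448 map (448/2 - 8/2 = 220). If your feature size or block size differs, the indices can be derived generically; the sketch below is illustrative, with H, W, and block as assumed names, not part of the original answer:

    import numpy as np

    H = W = 448                 # spatial size of the feature maps at the dilation layer
    block = 8                   # side length of the static center block
    start_h = (H - block) // 2  # 220 for H=448, block=8
    start_w = (W - block) // 2
    mask = np.zeros((1, H, W, 1), dtype="float32")
    mask[:, start_h:start_h + block, start_w:start_w + block, :] = 1.0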
    

    Replacing part of the feature

    If you replace part of the feature with constant values, that part of the feature will be static, but it will still participate in backpropagation (weights are still multiplied and summed over this region of the image, so the connection remains).

    constant = 0  # a constant of 0 cancels the kernel's contribution, but not the bias
    
    from tensorflow.keras.layers import Lambda

    def replace(x):
        return x*(1-mask) + constant*mask

    # place this before the dilation layer
    new_layer = Lambda(replace)(new_layer)
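
    Wired into the question's architecture, this would look as follows (a sketch; the surrounding Conv2D lines are copied from the question's code):

    new_layer = Conv2D(64, kernel_size=(3,3), padding='same', activation='relu', kernel_initializer='he_normal')(new_layer)
    new_layer = Lambda(replace)(new_layer)  # center 8x8 block held constant
    new_layer = Conv2D(1024, kernel_size=(7,7), dilation_rate=8, padding="same", activation='relu', kernel_initializer='he_normal', name='dilation')(new_layer)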
    

    Keeping the feature value, but stopping backpropagation

    Here, the weights of the dilation layer and the layers after it will be updated normally, but the weights before the dilation layer will receive no gradient from the central region.

    from tensorflow.keras import backend as K

    def stopBackprop(x):
        stopped = K.stop_gradient(x)
        return x*(1-mask) + stopped*mask

    # place this before the dilation layer
    new_layer = Lambda(stopBackprop)(new_layer)
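
    Putting everything together with the question's model (a sketch assuming a TensorFlow-backed Keras; the optimizer and loss are placeholders, not part of the original answer):

    import numpy as np
    from tensorflow.keras import backend as K
    from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Lambda
    from tensorflow.keras.models import Model

    mask = np.zeros((1, 448, 448, 1), dtype="float32")
    mask[:, 220:228, 220:228, :] = 1.0

    def stopBackprop(x):
        stopped = K.stop_gradient(x)       # identity in the forward pass, zero gradient in the backward pass
        return x*(1-mask) + stopped*mask

    input_layer = Input(shape=(896, 896, 3))
    x = Conv2D(32, (3, 3), padding="same", activation="relu")(input_layer)
    x = MaxPooling2D((2, 2), strides=(2, 2), padding="same")(x)
    x = Conv2D(64, (3, 3), padding="same", activation="relu")(x)
    x = Lambda(stopBackprop)(x)            # before the dilation layer
    x = Conv2D(1024, (7, 7), dilation_rate=8, padding="same", activation="relu", name="dilation")(x)
    model = Model(input_layer, x)
    model.compile(optimizer="adam", loss="mse")  # placeholder optimizer/loss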