Tags: tensorflow, keras, gradient, object-detection

Stopping gradient backpropagation through a particular layer in Keras


x = Conv2D(768, (3, 3), padding='same', activation='relu', kernel_initializer='normal',
           name='rpn_conv1', trainable=trainable)(base_layers)

x_class = Conv2D(num_anchors, (1, 1), activation='sigmoid', kernel_initializer='uniform',
                 name='rpn_out_class', trainable=trainable)(x)

# stop gradient backflow through regression layer
x_regr = Conv2D(num_anchors * 4, (1, 1), activation='linear', kernel_initializer='zero',
                name='rpn_out_regress', trainable=trainable)(x)

How can I use K.stop_gradient() to stop gradient backpropagation through the regression layer (x_regr) alone?


Solution

  • To use a custom function such as K.stop_gradient() inside a Keras model, wrap it in a Lambda layer.

    from keras import backend as K
    from keras.layers import Lambda

    # output_shape is not necessary with the TensorFlow backend, so it is omitted.
    x_regr_constant = Lambda(lambda t: K.stop_gradient(t))(x_regr)
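
Where stop_gradient is placed changes which weights still receive gradients. The following is a minimal sketch, not the asker's model: the scalar variables w_shared, w_cls and w_regr are hypothetical stand-ins for rpn_conv1, rpn_out_class and rpn_out_regress, and it calls tf.stop_gradient directly, which is what K.stop_gradient uses on the TensorFlow backend.

```python
import tensorflow as tf

# Hypothetical scalar stand-ins for the layers in the question:
# w_shared ~ rpn_conv1, w_cls ~ rpn_out_class, w_regr ~ rpn_out_regress.
w_shared = tf.Variable(2.0)
w_cls = tf.Variable(3.0)
w_regr = tf.Variable(5.0)
inp = tf.constant(1.0)

# Case 1: stop_gradient on the OUTPUT of the regression head
# (what wrapping x_regr in the Lambda layer does).
with tf.GradientTape() as tape:
    x = w_shared * inp
    y_cls = w_cls * x
    y_regr = tf.stop_gradient(w_regr * x)
    loss = y_cls + y_regr
g_shared_out, g_regr_out = tape.gradient(loss, [w_shared, w_regr])

# Case 2: stop_gradient on the INPUT of the regression head.
with tf.GradientTape() as tape:
    x = w_shared * inp
    y_cls = w_cls * x
    y_regr = w_regr * tf.stop_gradient(x)
    loss = y_cls + y_regr
g_shared_in, g_regr_in = tape.gradient(loss, [w_shared, w_regr])

# Case 1: the regression weight gets no gradient at all (None).
print("output detached:", g_shared_out, g_regr_out)
# Case 2: the regression weight is still trained; only the shared
# weight is shielded from the regression loss.
print("input detached: ", g_shared_in, g_regr_in)
```

In both cases the shared weight only receives the gradient from the classification branch. If the goal is to keep training the regression layer while protecting rpn_conv1 from its loss, the second placement (applying the Lambda to x before the regression Conv2D) is probably the one to try; if the regression branch should be frozen entirely for that loss, the first placement matches the answer above.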