Tags: tensorflow, keras, keras-layer

Dynamic Activation Function in Keras


I'm working on a research project that involves replacing certain ReLU activations with polynomial activations. The code I inherited uses Keras with the TensorFlow backend, a framework I have little experience with.

Essentially I'm building a vanilla ResNet graph and need to gradually swap out a few ReLUs for custom functions. In other words, my custom activation needs to do the following:

def activation(x):
    # quadratic approximation of ReLU
    approx = .1992 + .5002*x + .1997*x**2
    relu = tf.nn.relu(x)
    # blend weight: 1 at epoch 0 (pure ReLU), approaching 0 (pure polynomial)
    diff = (TOTAL_EPOCHS - CURRENT_EPOCH) / TOTAL_EPOCHS
    return (1-diff)*approx + diff*relu

The problem I'm having is figuring out how to make the function depend on the current epoch when training with Keras and model.fit. I have tried a few things, such as defining a custom layer, passing a counter variable, and using TensorFlow's global step variable, but I ran into annoying bugs with each attempt. Is there a simple way to do this that I'm overlooking? It seems like it should be trivial, but I lack experience with the framework.


Solution

  • You can use keras.callbacks.Callback to update a backend variable at the start of each epoch, which makes the activation function dynamic with keras and model.fit. Here's an example in which the activation function simply returns the current epoch; from the MSE values you can see that the current epoch participates in the computation of the activation.

    from keras.models import Model
    from keras.layers import Activation, Input
    from keras.utils.generic_utils import get_custom_objects
    import keras.backend as K
    from keras.callbacks import Callback
    
    class MonitorCallback(Callback):
        # Writes the current epoch into a backend variable at the start of each epoch
        def __init__(self, CURRENT_EPOCH):
            super(MonitorCallback, self).__init__()
            self.parm = CURRENT_EPOCH
        def on_epoch_begin(self, epoch, logs=None):
            K.set_value(self.parm, epoch)
    
    # Backend variable read by the activation and updated by the callback
    CURRENT_EPOCH = K.variable(0)
    TOTAL_EPOCHS = 8
    def custom_activation(x):
        # Demo only: the "activation" simply returns the current epoch
        return CURRENT_EPOCH
    
    num_input = Input(shape=(1,))
    get_custom_objects().update({'custom_activation': Activation(custom_activation)})
    output = Activation(custom_activation)(num_input)
    model = Model(num_input, output)
    model.compile(optimizer='rmsprop', loss='mse', metrics=['mse'])
    
    model.fit(x=[1], y=[2], epochs=TOTAL_EPOCHS, callbacks=[MonitorCallback(CURRENT_EPOCH)])
    
    # Output:
    Using TensorFlow backend.
    Epoch 1/8
    1/1 [==============================] - 2s 2s/step - loss: 4.0000 - mean_squared_error: 4.0000
    Epoch 2/8
    1/1 [==============================] - 0s 2ms/step - loss: 1.0000 - mean_squared_error: 1.0000
    Epoch 3/8
    1/1 [==============================] - 0s 2ms/step - loss: 0.0000e+00 - mean_squared_error: 0.0000e+00
    Epoch 4/8
    1/1 [==============================] - 0s 2ms/step - loss: 1.0000 - mean_squared_error: 1.0000
    Epoch 5/8
    1/1 [==============================] - 0s 3ms/step - loss: 4.0000 - mean_squared_error: 4.0000
    Epoch 6/8
    1/1 [==============================] - 0s 3ms/step - loss: 9.0000 - mean_squared_error: 9.0000
    Epoch 7/8
    1/1 [==============================] - 0s 3ms/step - loss: 16.0000 - mean_squared_error: 16.0000
    Epoch 8/8
    1/1 [==============================] - 0s 3ms/step - loss: 25.0000 - mean_squared_error: 25.0000
    

    Please keep in mind that epoch counts from zero in keras.callbacks.Callback. You can replace custom_activation with your own activation function, as sketched below.
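
    For the blend in the question, only the body of the activation changes. Here's a minimal sketch under the same pattern, reusing CURRENT_EPOCH and MonitorCallback from above (blended_activation is just an illustrative name; the coefficients are the ones from the question):

    import tensorflow as tf
    import keras.backend as K
    
    TOTAL_EPOCHS = 8
    CURRENT_EPOCH = K.variable(0.)  # updated each epoch by MonitorCallback
    
    def blended_activation(x):
        # quadratic approximation of ReLU (coefficients from the question)
        approx = .1992 + .5002*x + .1997*x**2
        relu = tf.nn.relu(x)
        # diff is 1.0 at epoch 0 (pure ReLU) and shrinks toward 0.0 (pure polynomial)
        diff = (TOTAL_EPOCHS - CURRENT_EPOCH) / TOTAL_EPOCHS
        return (1-diff)*approx + diff*relu

    Wire it in exactly as in the example, e.g. output = Activation(blended_activation)(num_input), and pass MonitorCallback(CURRENT_EPOCH) to model.fit. Note that since epoch counts from zero, diff only falls to 1/TOTAL_EPOCHS on the final epoch; if the blend should fully reach the pure polynomial, use (TOTAL_EPOCHS - 1 - CURRENT_EPOCH) / (TOTAL_EPOCHS - 1) instead.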