Tags: python, tensorflow, tensorflow2.0, activation-function

Add custom activation function to be used with a string


I followed this answer and did:

get_custom_objects().update(act_dispatcher)

where act_dispatcher is a dictionary of all the activation functions I want to add, e.g. {'fun_1': fun_1, 'fun_2': fun_2}.

The first thing that caught my attention is that, before I add anything, get_custom_objects() returns an empty dict {}. After adding the functions, I checked that every call to get_custom_objects() returns exactly what I added.

However, I am getting ValueError: Unknown activation function:<my_fun>

I added the lines

assert 'my_func' in get_custom_objects().keys()

tf.keras.layers.Dense(128, activation='my_func')

The assertion passes without problem, yet the error above is still raised in the Dense constructor.


The error occurs at deserialize_keras_object where:

  • custom_objects is None
  • _GLOBAL_CUSTOM_OBJECTS is {} (presumably this shouldn't be empty).
  • module_objects.get(object_name) returns None (module_objects seems correct).
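The three bullets describe a fallback chain: explicit custom_objects, then the global registry, then the module's built-in functions. A simplified pure-Python model of that chain shows how the symptoms can fit together: if the assertion inspects one registry while the layer consults a different, empty one, the lookup falls through to module_objects and fails. That two registries are involved is an assumption, consistent with the empty _GLOBAL_CUSTOM_OBJECTS and with the Installation note importing get_custom_objects from keras.utils.generic_utils (standalone Keras) while the layer comes from tf.keras. The deserialize function, the registries and my_func below are hypothetical.

```python
# Simplified model of a string -> activation lookup (NOT the real Keras code).
# registry_a / registry_b stand in for two separate libraries' global registries.


def relu(x):
    return max(0.0, x)


def my_func(x):
    return 2.0 * x  # hypothetical custom activation


BUILTIN_ACTIVATIONS = {'relu': relu}  # plays the role of module_objects

registry_a = {}  # registry the user updates (e.g. standalone keras)
registry_b = {}  # registry the layer consults (e.g. tf.keras)


def deserialize(name, global_objects, module_objects, custom_objects=None):
    # 1. explicit custom_objects argument
    if custom_objects and name in custom_objects:
        return custom_objects[name]
    # 2. this library's global custom-object registry
    if name in global_objects:
        return global_objects[name]
    # 3. built-in functions of the module; give up if absent
    obj = module_objects.get(name)
    if obj is None:
        raise ValueError(f'Unknown activation function: {name}')
    return obj


registry_a.update({'my_func': my_func})

try:
    deserialize('my_func', registry_b, BUILTIN_ACTIVATIONS)
except ValueError as err:
    print(err)  # Unknown activation function: my_func

print(deserialize('my_func', registry_a, BUILTIN_ACTIVATIONS)(3.0))  # 6.0
```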

Installation

I am using an Anaconda environment. Because of an error I first got with from keras.utils.generic_utils import get_custom_objects, I installed keras-applications with conda install -c conda-forge keras-applications.


Solution

  • I can run the code below without any error. Try changing from tensorflow.keras.utils.generic_utils import get_custom_objects to from tensorflow.keras.utils import get_custom_objects and see if it helps:

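    from tensorflow.keras import backend as K
    from tensorflow.keras.utils import get_custom_objects
    from tensorflow.keras.layers import Activation, Conv2D
    from tensorflow.keras.models import Sequential
    
    # Swish-style custom activation: x * sigmoid(beta * x)
    def my_func(x, beta=1.0):
        return x * K.sigmoid(beta * x)
    
    model = Sequential()
    model.add(Conv2D(64, (3, 3)))
    # Passing the function object directly needs no registration
    model.add(Activation(my_func))
    
    # Register the name in tf.keras's global custom-object registry
    # before any layer refers to the activation by string
    get_custom_objects().update({'my_func': Activation(my_func)})
    
    # The string 'my_func' is now resolved through that registry
    model.add(Conv2D(64, (3, 3), activation='my_func'))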
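The answer registers Activation(my_func), a layer instance, rather than my_func itself. In a lookup like the one described for deserialize_keras_object, the registered value is returned and later called on the layer's output, so presumably any callable works. A pure-Python sketch under that assumption, operating on floats instead of tensors; ActivationLike is a hypothetical stand-in for tf.keras.layers.Activation, not the real class:

```python
import math


def my_func(x, beta=1.0):
    # x * sigmoid(beta * x), computed on a plain float
    return x / (1.0 + math.exp(-beta * x))


class ActivationLike:
    """Hypothetical stand-in for a layer that wraps an activation function."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x):
        return self.fn(x)


registry = {}
registry.update({'my_func': ActivationLike(my_func)})  # wrapped, as in the answer
registry.update({'my_func_plain': my_func})            # bare function

# A string lookup hands back whatever was registered; both entries are callable.
for name in ('my_func', 'my_func_plain'):
    activation = registry[name]
    print(name, activation(2.0))
```

Under this model, registering the bare function is the simpler choice; wrapping it only adds an extra call layer.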