
TypeError: Using Custom Activation Function while Mixed Precision Enabled?


I was trying to use a custom activation in a training pipeline with mixed precision enabled, but got the following error:

TypeError: Input 'y' of 'Mul' Op has type float32 that does not match type float16 of argument 'x'.

Reproduce

Enabling Mixed precision...

import tensorflow as tf 

policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
tf.keras.mixed_precision.experimental.set_policy(policy)
print('Mixed precision enabled')

Custom activation...

def ARelu(x, alpha=0.90, beta=2.0):
    alpha = tf.clip_by_value(alpha, clip_value_min=0.01, clip_value_max=0.99)
    beta  = 1 + tf.math.sigmoid(beta)
    return tf.nn.relu(x) * beta - tf.nn.relu(-x) * alpha

Training...

import tensorflow as tf

(xtrain, ytrain), (xtest, ytest) = tf.keras.datasets.mnist.load_data()

def pre_process(inputs, targets):
    inputs  = tf.expand_dims(inputs, -1)
    targets = tf.one_hot(targets, depth=10)
    return tf.divide(inputs, 255), targets

train_data = tf.data.Dataset.from_tensor_slices((xtrain, ytrain)).\
    take(10_000).shuffle(10_000).batch(8).map(pre_process)
test_data = tf.data.Dataset.from_tensor_slices((xtest, ytest)).\
    take(1_000).shuffle(1_000).batch(8).map(pre_process)

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(filters=16, kernel_size=(3, 3), strides=(1, 1),
                           input_shape=(28, 28, 1), activation=ARelu),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2)),

    tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), strides=(1, 1),
                           activation=ARelu),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2)),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation=ARelu),
    tf.keras.layers.Dense(10, activation='softmax', dtype=tf.float32),
])

opt = tf.keras.optimizers.Adam()

model.compile(loss='categorical_crossentropy', optimizer=opt)
history = model.fit(train_data, validation_data=test_data, epochs=10)

# ------------------

TypeError: Input 'y' of 'Mul' Op has type float32 that does not match type float16 of argument 'x'.

However, without mixed precision, it works. I understand the problem is simply a type mismatch, but where should I look to fix it?

Additionally, while trying to solve it, I found that tf.keras.mixed_precision.LossScaleOptimizer is recommended for avoiding numeric underflow. Is it something we should use for mixed-precision training?


Solution

  • The fix is to cast your alpha and beta to float16 (more generally, to the dtype of the activation's input) rather than casting the input of your activation layer to float32.

    DETAILS:

    A major reason for using mixed precision (MP) is to reduce the memory footprint during training. It does so by storing each layer's output in FP16, since memory consumption is dominated by the stored activations rather than the weights. If you recast the layer output to FP32 inside the custom activation function, you lose these savings and may even need more memory than full-precision training, because two copies of each activation then exist.
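
A minimal sketch of that fix, assuming the activation always receives a tensor (which is what Keras passes to it). Casting to x.dtype instead of hard-coding float16 is my own choice, so the same function also works when mixed precision is disabled. The rest of the function is unchanged from the question.

```python
import tensorflow as tf

def ARelu(x, alpha=0.90, beta=2.0):
    # Cast the scalar parameters to the dtype of the incoming tensor:
    # float16 under the mixed_float16 policy, float32 otherwise.
    # Without this, clip_by_value/sigmoid produce float32 tensors and the
    # multiplication with a float16 activation raises the 'Mul' TypeError.
    alpha = tf.cast(alpha, x.dtype)
    beta = tf.cast(beta, x.dtype)
    alpha = tf.clip_by_value(alpha, clip_value_min=0.01, clip_value_max=0.99)
    beta = 1 + tf.math.sigmoid(beta)
    return tf.nn.relu(x) * beta - tf.nn.relu(-x) * alpha

# Quick check with a float16 input, as a layer would produce under mixed precision.
x16 = tf.constant([-1.0, 0.0, 2.0], dtype=tf.float16)
y16 = ARelu(x16)
print(y16.dtype)  # the output keeps the input dtype
```

On the LossScaleOptimizer part of the question: as far as I know, recent TensorFlow 2.x releases wrap the optimizer for you when you train with model.fit under the mixed_float16 policy. Wrapping it explicitly in tf.keras.mixed_precision.LossScaleOptimizer mainly matters for custom training loops. Check the mixed precision guide for your TensorFlow version.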