Tags: tensorflow, deep-learning, pytorch

Does tf.math.reduce_max allow gradient flow like torch.max?


I am trying to build a multi-label binary classification model in TensorFlow. The model has a tf.math.reduce_max operator between two layers (it is not max pooling; it is there for a different purpose).

And the number of classes is 3.

I am using Binary Cross Entropy loss and using Adam optimizer.

Even after hours of training, when I check the predictions, they are all in the range 0.49 to 0.51.

It seems that the model is not learning anything and is making essentially random predictions, which makes me think that the tf.math.reduce_max operation may be causing the problem.

However, I read on the web that the torch.max function allows back propagation of gradients through it.

When I checked the graph in TensorBoard, it shows as unconnected at the tf.math.reduce_max operator. So, does this operator allow gradients to back-propagate through it?

EDIT: Adding the code

import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Input, GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model

input_tensor = Input(shape=(256, 256, 3))
base_model_toc = VGG16(input_tensor=input_tensor, weights='imagenet', pooling=None, include_top=False)

x = base_model_toc.output
x = GlobalAveragePooling2D()(x)
x = tf.math.reduce_max(x, axis=0, keepdims=True)
x = Dense(1024, activation='relu')(x)
output_1 = Dense(3, activation='sigmoid')(x)

model_a = Model(inputs=base_model_toc.input, outputs=output_1)

for layer in base_model_toc.layers:
    layer.trainable = True

The tf.math.reduce_max is done along axis=0 because that is what this model requires.
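For reference, here is a small sketch of what that reduction does to the pooled features (the batch size of 4 is purely illustrative; VGG16 without its top produces 512 channels, so GlobalAveragePooling2D yields one 512-dimensional vector per image):

    import tensorflow as tf

    # Hypothetical pooled output for a batch of 4 images: shape (4, 512)
    pooled = tf.random.uniform((4, 512))
    # Element-wise maximum over the batch axis, keeping a leading dimension of 1
    reduced = tf.math.reduce_max(pooled, axis=0, keepdims=True)
    print(pooled.shape, '->', reduced.shape)  # (4, 512) -> (1, 512)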

The optimizer I am using is Adam with an initial learning rate of 0.00001.
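For completeness, a minimal sketch of the compile step, assuming model_a as defined above and nothing beyond the stated loss, optimizer, and learning rate:

    from tensorflow.keras.optimizers import Adam

    # Binary cross-entropy loss with Adam at the stated initial learning rate
    model_a.compile(optimizer=Adam(learning_rate=1e-5),
                    loss='binary_crossentropy')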


Solution

  • Yes, tf.math.reduce_max does allow gradients to flow. It is easy to check (this is TensorFlow 2.x but it is the same result in 1.x):

    import tensorflow as tf
    
    with tf.GradientTape() as tape:
        x = tf.linspace(0., 2. * 3.1416, 10)
        tape.watch(x)
        # A sequence of operations involving reduce_max
        y = tf.math.square(tf.math.reduce_max(tf.math.sin(x)))
    # Check gradients
    g = tape.gradient(y, x)
    print(g.numpy())
    # [ 0.         0.         0.3420142 -0.        -0.        -0.
    #  -0.         0.         0.         0.       ]
    

    As you can see, there is a valid gradient of y with respect to x. Only one of the values is nonzero, because it is the value that produced the maximum, so it is the only element of x that affects the value of y. This is the correct gradient for the operation.
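
    The same kind of check can be run on the full model from the question. A minimal sketch (assuming model_a is built as shown above, with a hypothetical random batch used only to probe gradient flow) that confirms gradients reach the trainable weights through the reduce_max operation:

    import tensorflow as tf

    # Hypothetical random batch of 4 images; labels have shape (1, 3) because
    # reduce_max over axis=0 collapses the batch dimension to 1
    images = tf.random.uniform((4, 256, 256, 3))
    labels = tf.cast(tf.random.uniform((1, 3), maxval=2, dtype=tf.int32), tf.float32)

    loss_fn = tf.keras.losses.BinaryCrossentropy()
    with tf.GradientTape() as tape:
        preds = model_a(images, training=True)  # shape (1, 3)
        loss = loss_fn(labels, preds)
    grads = tape.gradient(loss, model_a.trainable_variables)
    # If gradients flow through reduce_max, none of these should be None
    print(any(g is None for g in grads))  # expect: False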