python, tensorflow, keras, loss-function, autoencoder

Correct implementation of Autoencoder MSE loss function in TF2/Keras


Can someone explain to me the difference between the following two, please?

Assuming a vanilla autoencoder with real-valued inputs, according to this and this source, its loss function should be as follows: (a) for each element of an example we compute the squared difference, (b) we sum over all elements of the example, and (c) we take the mean over all examples.

import tensorflow as tf

def MSE_custom(y_true, y_pred):
    # 0.5 * sum of squared errors over the elements of each example (axis 1),
    # then the mean over all examples in the batch
    return tf.reduce_mean(
        0.5 * tf.reduce_sum(
            tf.square(tf.subtract(y_true, y_pred)),
            axis=1
        )
    )

However, in the majority of implementations I see: autoencoder.compile(loss='mse', ...).
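For reference, a minimal toy autoencoder (hypothetical architecture, purely for illustration) compiled this way would look like the following; the string 'mse' resolves to tf.keras.losses.MeanSquaredError:

import tensorflow as tf

# Toy 3 -> 2 -> 3 autoencoder, just to show the usual compile call
inputs = tf.keras.Input(shape=(3,))
encoded = tf.keras.layers.Dense(2, activation='relu')(inputs)
decoded = tf.keras.layers.Dense(3)(encoded)
autoencoder = tf.keras.Model(inputs, decoded)

autoencoder.compile(optimizer='adam', loss='mse')
# equivalent: autoencoder.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())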

I fail to see how the two are the same. Consider this example:

y_true = [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]]
y_pred = [[0.0, 0.8, 0.9], [0.5, 0.7, 0.6], [0.8, 0.7, 0.5]]

result1 = MSE_custom(y_true, y_pred)  # 0.355 

mse = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.AUTO)
result2 = mse(y_true, y_pred)  # 0.237

What am I missing?


Solution

  • There are two differences.

    1. The Keras loss averages over the feature dimension instead of summing, i.e. your reduce_sum over axis 1 should be replaced by reduce_mean.
    2. The Keras loss does not multiply by 0.5.

    In your case each example has three features, so we can get from your result to the Keras loss by dividing by 3 (turning the sum into a mean) and multiplying by 2 (cancelling the 0.5). Indeed, 0.355 * 2/3 ≈ 0.237.

    These differences might throw you off, but they are ultimately irrelevant for training: both the division by N and the factor of 0.5 are constants, so they only scale the gradients by a constant factor (which can be absorbed into the learning rate).
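    To make this concrete, here is a small sketch (reusing the y_true / y_pred values from the question, re-declared as tensors) showing that applying both changes to MSE_custom reproduces the Keras value:

    import tensorflow as tf

    y_true = tf.constant([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]])
    y_pred = tf.constant([[0.0, 0.8, 0.9], [0.5, 0.7, 0.6], [0.8, 0.7, 0.5]])

    # MSE_custom with both changes applied: mean over features instead of 0.5 * sum
    def mse_keras_style(y_true, y_pred):
        return tf.reduce_mean(tf.reduce_mean(tf.square(y_true - y_pred), axis=1))

    print(mse_keras_style(y_true, y_pred).numpy())                     # ~0.237
    print(tf.keras.losses.MeanSquaredError()(y_true, y_pred).numpy())  # ~0.237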

    Edit: The following computation should give you the same result as the Keras loss:

    # note: y_true and y_pred must be tensors (or NumPy arrays) for the overloaded operators to work
    mse_custom = tf.reduce_mean((y_true - y_pred) ** 2)
    

    I used overloaded Python operators instead of the TF ops (tf.subtract, tf.square) for simplicity. This averages the entire 2D matrix at once, which is the same as first averaging over axis 1 and then averaging that result over axis 0, as the quick check below confirms.
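    As a quick check (purely illustrative, reusing the tensors defined in the sketch above), the two orderings of the averaging give the same number:

    sq = tf.square(y_true - y_pred)
    print(tf.reduce_mean(sq).numpy())                                  # mean over the whole matrix, ~0.237
    print(tf.reduce_mean(tf.reduce_mean(sq, axis=1), axis=0).numpy())  # mean over axis 1, then axis 0, ~0.237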