tensorflow, machine-learning, keras, autoencoder, loss-function

Correct way to calculate MSE for autoencoders with batch-training


Suppose you have a network representing an autoencoder (AE). Let's assume it has 90 inputs/outputs. I want to batch-train it with batches of size 100. I will denote my input with x and my output with y.

Now, I want to use the MSE to evaluate the performance of my training process. To my understanding, the input/output dimensions for my network are of size (100, 90).

The first part of the MSE calculation is performed element-wise:

(x - y)²

which gives me a matrix of size (100, 90) again. To illustrate the problem, here is a sketch of what this matrix looks like:

[[x1 x2 x3 ... x90],    # sample 1 of batch
 [x1 x2 x3 ... x90],    # sample 2 of batch
 .
 .
 [x1 x2 x3 ... x90]]    # sample 100 of batch
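
This element-wise step can be reproduced with a small NumPy sketch (the array contents are random placeholders; only the shapes match the question):

import numpy as np

x = np.random.rand(100, 90)   # batch of 100 inputs with 90 features each
y = np.random.rand(100, 90)   # corresponding reconstructions of the autoencoder

sq_err = (x - y) ** 2         # element-wise squared error
print(sq_err.shape)           # (100, 90), one row of quadratic errors per sample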

From here on, I have stumbled across various versions of calculating the error. The goal of all versions is to reduce this matrix to a scalar, which can then be optimized.

Version 1:

Sum the quadratic errors within each sample first, then take the mean over all samples, e.g.:

v1 = 
[ SUM_of_qerrors_1,        # equals sum(x1 to x90)
  SUM_of_qerrors_2,
  ...
  SUM_of_qerrors_100 ]

result = mean(v1)
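
As a NumPy sketch (again with random placeholder data), Version 1 would look like this:

import numpy as np

x = np.random.rand(100, 90)
y = np.random.rand(100, 90)

v1 = np.sum((x - y) ** 2, axis=1)   # sum of quadratic errors per sample, shape (100,)
result = v1.mean()                  # mean over the 100 samples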

Version 2:

Calculate the mean of the quadratic errors per sample, then take the mean over all samples, e.g.:

v2 = 
[ MEAN_of_qerrors_1,        # equals mean(x1 to x90)
  MEAN_of_qerrors_2,
  ...
  MEAN_of_qerrors_100 ]

result = mean(v2)
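
The corresponding NumPy sketch for Version 2 (same placeholder data); note that it differs from Version 1 only by the constant factor 90, the number of features:

import numpy as np

x = np.random.rand(100, 90)
y = np.random.rand(100, 90)

v2 = np.mean((x - y) ** 2, axis=1)   # mean of quadratic errors per sample, shape (100,)
result = v2.mean()                   # mean over the 100 samples

# For each sample, sum(errors) == 90 * mean(errors),
# so the Version 1 result is exactly 90 times the Version 2 result.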

Personally, I think that Version 1 is the correct way to do it, because the commonly used cross-entropy is calculated in the same manner. But if I use Version 1, the result isn't really the MSE.
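
For comparison, this is my understanding of the reduction pattern of categorical cross-entropy (a pure-NumPy sketch with made-up probabilities, not the actual Keras implementation): the per-sample loss is a sum over the class axis, and the batch loss is the mean of those sums.

import numpy as np

# Made-up one-hot targets and uniform predictions for 4 samples and 3 classes.
y_true = np.eye(3)[[0, 1, 2, 1]]
y_pred = np.full((4, 3), 1.0 / 3.0)

per_sample_ce = -np.sum(y_true * np.log(y_pred), axis=-1)   # SUM over classes per sample
batch_loss = per_sample_ce.mean()                           # mean over the batch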

I've found a Keras example (https://keras.io/examples/variational_autoencoder/), but unfortunately I wasn't able to figure out how Keras handles this under the hood during batch training.

I would be grateful for a hint on how this is handled under the hood by Keras (and therefore TensorFlow), or for an indication of which version is correct.

Thank you!


Solution

  • Version 2, i.e. computing the mean of the quadratic errors per sample and then taking the mean of the resulting numbers, is what Keras does:

    from keras import backend as K   # backend module used by the Keras losses

    def mean_squared_error(y_true, y_pred):
        # mean over the last (feature) axis -> one value per sample
        return K.mean(K.square(y_pred - y_true), axis=-1)
    

    However, note that taking the average over samples is done in another part of the code which I have explained extensively here and here.
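
    To convince yourself, a small check (assuming a standalone Keras installation with the backend module available; the arrays are random placeholders) can compare the per-sample values returned by the quoted loss with a manual Version 2 computation:

    import numpy as np
    from keras import backend as K

    x = np.random.rand(100, 90).astype("float32")   # inputs
    y = np.random.rand(100, 90).astype("float32")   # reconstructions

    # Per-sample loss exactly as in the quoted mean_squared_error (mean over the feature axis).
    per_sample = K.eval(K.mean(K.square(K.constant(y) - K.constant(x)), axis=-1))
    print(per_sample.shape)       # (100,)

    # Averaging over the batch gives the value Keras reports, i.e. Version 2 -- which here
    # equals the plain MSE over all entries, since every sample has the same number of features.
    print(per_sample.mean())
    print(np.mean((x - y) ** 2))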