I am trying to solve a simple binary classification problem using an LSTM, and I am trying to figure out the correct loss function for the network. The issue is that when I use binary cross-entropy as the loss function, the loss values for training and testing are relatively high compared to those I get with the mean squared error (MSE) function.
While researching this, I found the usual justification that binary cross-entropy should be used for classification problems and MSE for regression problems. However, in my case, I am getting better accuracy and a lower loss value with MSE for binary classification.
I am not sure how to justify these results. Why not use mean squared error for classification problems?
I'd like to share my understanding of the MSE and binary cross-entropy functions.
In the case of classification, we take the argmax over the predicted class probabilities of each training instance.
Now, consider an example of a binary classifier where the model predicts the probabilities [0.49, 0.51]. In this case, the model will return 1 as the prediction.
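The argmax step described above can be sketched in plain Python as follows, assuming the model's output for one instance is the list [P(class 0), P(class 1)]:

```python
# Predicted class probabilities for one instance: [P(class 0), P(class 1)]
probs = [0.49, 0.51]

# argmax: the index of the largest probability is the predicted label
predicted_label = max(range(len(probs)), key=lambda i: probs[i])
print(predicted_label)  # prints 1
```

Note that the argmax throws away how close the decision was: 0.51 and 0.99 both map to label 1.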
Now, assume that the actual label is also 1.
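In such a case, if MSE is computed on the predicted label, it will return 0 as the loss value, because the prediction matches the label exactly, whereas binary cross-entropy, which is computed on the predicted probability, will return a "tangible" value: -ln(0.51) ≈ 0.67.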
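A minimal sketch of the single-sample comparison above, assuming a true label of 1 and a predicted probability of 0.51 for class 1. It also shows, for contrast, the squared error computed on the probability itself rather than on the thresholded label:

```python
import math

p = 0.51  # predicted probability of class 1
y = 1     # true label

# Binary cross-entropy on the probability: -[y*log(p) + (1-y)*log(1-p)]
bce = -(y * math.log(p) + (1 - y) * math.log(1 - p))

# Squared error on the thresholded (argmax) label
y_hat = 1 if p > 0.5 else 0
se_label = (y - y_hat) ** 2

# Squared error on the probability itself
se_prob = (y - p) ** 2

print(f"BCE = {bce:.4f}")              # BCE = 0.6733
print(f"SE (label) = {se_label}")      # SE (label) = 0
print(f"SE (prob) = {se_prob:.4f}")    # SE (prob) = 0.2401
```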
And if the trained model predicts similar probabilities for all data samples, binary cross-entropy will accumulate a large total loss, whereas MSE on the predicted labels will still return 0.
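To illustrate the accumulation, here is a sketch on a hypothetical dataset of 1000 samples, all with true label 1, where the model always outputs 0.51 (both the dataset and the helper `bce` are made up for illustration):

```python
import math

# Hypothetical dataset: every sample has true label 1 and the model outputs 0.51
n_samples = 1000
labels = [1] * n_samples
probs = [0.51] * n_samples

def bce(y, p):
    """Binary cross-entropy for a single sample (hypothetical helper)."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

# Sum of per-sample cross-entropy losses
total_bce = sum(bce(y, p) for y, p in zip(labels, probs))

# Sum of squared errors between true labels and argmax predictions
total_se_label = sum((y - (1 if p > 0.5 else 0)) ** 2 for y, p in zip(labels, probs))

print(f"accumulated BCE: {total_bce:.1f}")                          # accumulated BCE: 673.3
print(f"accumulated squared error on labels: {total_se_label}")     # 0
```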
According to MSE, it is a perfect model, but it is actually not a good model. That is why we should not use MSE for classification.