machine-learning neural-network backpropagation

Validation Set in Backpropagation in a Neural Network


I have a neural network model, and so far I am running the training set forward, calculating the errors, and adjusting the weights.

As I understand it, after I do this for each training set example, I need to run an example from the validation set forward and calculate the errors. When the validation set error stops decreasing but the training set error is still decreasing, it is time to stop because overfitting is starting to occur. After we stop, we use the test set to calculate how much error is in our network.

Please correct me if there are any mistakes so far.

My question is what error are we comparing? Are we just comparing the error of the output layer? Or are we comparing the errors from every node? If so, how exactly do we define the overall error of the network, just sum up all the errors?


Solution

  • My question is what error are we comparing?

    We compare the error on the output layer only. If you plot error against epoch, you get two curves. The training error keeps going down as you run more epochs. The validation error goes down up to a certain point and then starts to go up. That upturn indicates overfitting, and you want to find the point where the validation error was lowest.

    Note that you are talking about individual samples while I am talking about epochs. For batch methods, these errors are usually plotted after one full pass over the data set (training or validation), so each point on the plot is the mean error or mean squared error for that epoch.
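
    A minimal sketch of this idea in plain Python. The function names, the `patience` parameter (how many epochs to wait for an improvement before giving up), and both error curves are assumptions made up for illustration; the numbers are not measured results.

```python
def mean_squared_error(predictions, targets):
    """Mean squared error over a data set, measured at the output layer only."""
    total = 0.0
    count = 0
    for pred, target in zip(predictions, targets):
        for p, t in zip(pred, target):
            total += (p - t) ** 2
            count += 1
    return total / count


def early_stopping_epoch(val_errors, patience=3):
    """Return (best_epoch, stop_epoch) for a per-epoch validation error curve.

    best_epoch is the epoch with the lowest validation error seen so far;
    stop_epoch is where we give up after `patience` epochs without improvement.
    """
    best_epoch = 0
    best_error = float("inf")
    for epoch, err in enumerate(val_errors):
        if err < best_error:
            best_error = err
            best_epoch = epoch
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch
    return best_epoch, len(val_errors) - 1


# Hypothetical curves, one value per epoch: training error keeps falling,
# validation error bottoms out and then rises.
train_curve = [0.90, 0.60, 0.40, 0.30, 0.22, 0.17, 0.13, 0.10]
val_curve = [0.95, 0.70, 0.52, 0.45, 0.47, 0.50, 0.55, 0.61]

best, stopped = early_stopping_epoch(val_curve, patience=3)
print("keep the weights from epoch", best, "; training stopped at epoch", stopped)
```

    In practice you would save a copy of the weights whenever the validation error improves, so the weights from the best epoch can be restored after stopping.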


    Also, if we have more than 1 output, are we just taking the sum of the errors in the output layer, or should it be some kind of weighted sum?

    The multiple-output case is interesting. Basically, we are trying to find the early stopping point at which to stop training the weights. In the very last layer of a multiple-output network, the weights are trained using different error derivatives and can have different optimal early stopping points. You may want to plot the errors separately if you think that is the case. Otherwise, a simple sum of the errors is sufficient. A weighted sum would mean that you care more about optimizing one output than another, even when that causes the other one(s) to over- or under-train.

    If you are thinking about implementing separate early stopping points, you can use the sum of the MSEs to get the stopping point for all internal weights, which depend on all the error derivatives. For the weights in the last layer, use each output's own MSE to get its separate stopping point.
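
    As a rough illustration of the two options, the sketch below computes one MSE per output unit, then picks a shared stopping epoch from the summed MSE and a separate stopping epoch per output. The helper names and the validation history are hypothetical, invented only for this example.

```python
def per_output_mse(predictions, targets):
    """One mean squared error per output unit, averaged over the samples."""
    n_outputs = len(targets[0])
    sums = [0.0] * n_outputs
    for pred, target in zip(predictions, targets):
        for k in range(n_outputs):
            sums[k] += (pred[k] - target[k]) ** 2
    return [s / len(targets) for s in sums]


def argmin(values):
    """Index of the smallest value (first one on ties)."""
    return min(range(len(values)), key=values.__getitem__)


# Hypothetical validation history: one row per epoch, one column per output unit.
val_history = [
    [0.50, 0.80],
    [0.30, 0.60],
    [0.20, 0.45],
    [0.25, 0.35],
    [0.33, 0.30],
    [0.40, 0.34],
]

# Shared stopping point (e.g. for the internal weights): lowest summed MSE.
summed = [sum(row) for row in val_history]
shared_stop = argmin(summed)

# Separate stopping points for the last-layer weights of each output.
n_outputs = len(val_history[0])
per_output_stop = [
    argmin([row[k] for row in val_history]) for k in range(n_outputs)
]

print("shared:", shared_stop, "per output:", per_output_stop)
```

    Here the two outputs reach their lowest validation error at different epochs, which is the situation where plotting them separately pays off.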

    Let's say I have 60% training, 20% validation, and 20% test set. For each epoch, I run through the 60 training set samples while adjusting the weights on each sample and also calculating the error on each validation sample.

    Another way to do the weight update is to calculate the update for each sample and then apply the average of all updates at the end of the epoch. This is a good choice if your training data has noise, outliers, or misclassified samples. For example, a couple of outliers will not be able to massively distort the weights, since their 'bad' updates get averaged out with the other 'good' updates.
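
    The difference can be seen on a toy problem. The sketch below assumes a one-weight linear model `y = w * x` with squared error, which is far simpler than a real network; the data, the learning rate, and the function names are made up for illustration.

```python
def gradient(w, x, y):
    """Gradient of 0.5 * (w * x - y) ** 2 with respect to w."""
    return (w * x - y) * x


def online_epoch(w, samples, lr):
    """Per-sample updates: the weight changes after every sample."""
    for x, y in samples:
        w -= lr * gradient(w, x, y)
    return w


def batch_epoch(w, samples, lr):
    """Averaged update: one weight change at the end of the epoch."""
    mean_grad = sum(gradient(w, x, y) for x, y in samples) / len(samples)
    return w - lr * mean_grad


# Three clean samples consistent with w = 2, plus one outlier at the end.
samples = [(1.0, 2.0), (1.0, 2.0), (1.0, 2.0), (1.0, 20.0)]
start_w = 2.0

w_online = online_epoch(start_w, samples, lr=0.5)
w_batch = batch_epoch(start_w, samples, lr=0.5)

print("online:", w_online, "batch:", w_batch)
```

    In this toy run the outlier pulls the per-sample weight much further from 2 than the averaged update does. How large the effect is in a real network depends on the learning rate and on where the outliers fall in the sample order.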

    Since there are only 1/3 as many validation samples as training samples, do I run through the validation 3 times for each epoch?

    Why do we iterate over the validation set? Do we calculate the validation error to get weight updates? No. We do all our updating using only the training set. Validation is only there to see how the trained model generalizes outside of the training data. Think of it as a test before the test you run with the test set. Now, does it make sense to run over the validation set 3 times in each epoch? No, it doesn't.
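
    A small sketch of why repeating the validation pass adds nothing, assuming the same kind of toy one-weight model `y = w * x` (an assumption for brevity; the names and data are hypothetical). Training changes the weight, evaluation only reads it, and the mean error does not depend on how many times the same samples are counted.

```python
def predict(w, x):
    return w * x


def train_epoch(w, train_samples, lr):
    """One pass over the training set; this is the only place w changes."""
    for x, y in train_samples:
        w -= lr * (predict(w, x) - y) * x
    return w


def evaluate(w, samples):
    """Mean squared error over a set. Read-only: no weight update happens here."""
    return sum((predict(w, x) - y) ** 2 for x, y in samples) / len(samples)


# Hypothetical split with three times as many training as validation samples.
train_set = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (1.5, 3.0), (2.5, 5.0), (0.5, 1.0)]
val_set = [(4.0, 8.0), (1.0, 2.5)]

w = train_epoch(0.0, train_set, lr=0.05)

once = evaluate(w, val_set)
thrice = evaluate(w, val_set * 3)  # same samples counted three times

print("validation MSE, one pass:", once)
print("validation MSE, three passes:", thrice)
```

    Because the error is averaged over the set, a single validation pass per epoch gives the same number (up to floating-point rounding) as three passes.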

    I use the last calculated weights for online learning correct?

    Yes. Error calculation and weight updates happen as new samples come in.

    When we use the test set to calculate the error of our final model, are we using mse for this or does it even really matter too much which we use?

    If your model produces real-valued output, use MSE. If your system is trying to solve a classification problem, use classification error. For example, a 10% classification error means that 10% of the test set was misclassified by your model during testing.
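
    Both measures are short to compute. The sketch below uses made-up predictions and labels purely for illustration; the function names are not from any particular library.

```python
def mse(predictions, targets):
    """Mean squared error for real-valued outputs."""
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)


def classification_error(predicted_labels, true_labels):
    """Fraction of samples whose predicted label differs from the true label."""
    wrong = sum(1 for p, t in zip(predicted_labels, true_labels) if p != t)
    return wrong / len(true_labels)


# Regression example (hypothetical test-set values).
test_mse = mse([2.0, 4.0], [1.0, 2.0])

# Classification example: 10 hypothetical test samples, one misclassified.
true_labels = [0, 1, 1, 0, 1, 0, 0, 1, 1, 0]
predicted_labels = [0, 1, 1, 0, 1, 0, 0, 1, 1, 1]
test_error = classification_error(predicted_labels, true_labels)

print("test MSE:", test_mse)
print("test classification error:", test_error)
```

    One of the ten labels is wrong, so the classification error comes out at 10%, matching the example above.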