Tags: java, validation, neural-network, cross-validation, encog

Encog - Training a neural network with Cross-validation methods and validating the training


I would like to stop training a network once I see the error calculated from the validation set starts to increase. I'm using a BasicNetwork with RPROP as the training algorithm, and I have the following training iteration:

void trainCrossValidation(BasicNetwork network, MLDataSet training, MLDataSet validation) {

    FoldedDataSet folded = new FoldedDataSet(training);
    MLTrain train = new ResilientPropagation(network, folded);
    CrossValidationKFold trainFolded = new CrossValidationKFold(train, KFOLDS);
    // Early stopping is checked against a separate hold-out validation set.
    trainFolded.addStrategy(new SimpleEarlyStoppingStrategy(validation));

    int epoch = 1;
    do {
        trainFolded.iteration();
        logger.debug("Iter. " + epoch + ": Error = " + trainFolded.getError());
        epoch++;
    } while (!trainFolded.isTrainingDone() && epoch < MAX_ITERATIONS);

    trainFolded.finishTraining();
}

Unfortunately, it is not working as expected. The method takes a very long time to execute and does not seem to stop at the right moment. I want training to be aborted at the exact instant the validation error begins to grow, that is, after the ideal number of training iterations.

Is there a way to extract the validation data directly from the cross-validation folds instead of creating another MLDataSet exclusively for validation? If so, how?

Which parameter should I use to stop the training? Can you show me the modifications needed to achieve this? How can I use cross-validation and SimpleEarlyStoppingStrategy together? I'm pretty confused.

Thank you so much for any assistance.


Solution

  • I think there are a couple of points of confusion here.

    • One thing is stopping training when the error starts to increase on a different(!) set of data. This set is normally called the validation dataset. That applies within a single training run, i.e. within the loop that runs up to the maximum number of epochs (your do-while loop). For this, you need to keep track of the error measured on the previous iteration: run the network on the validation dataset, get the error, and compare it with the error from the next epoch (see the first sketch at the end of this answer).

    • Another thing is cross-validation. Here, you train the network a number of times. From the whole set of training runs, you estimate the goodness of the network. It is a more complex, more robust approach, with different variations (see the second sketch at the end of this answer). I like this diagram.

    • Lastly: the fact that you stop training at the point where the error starts to increase doesn't mean you have found the ideal number of training epochs. You may be caught in a local minimum, a common issue in these models.

    Hope it helps you a bit :)
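
Here is a minimal sketch of the first point, assuming Encog 3: manual early stopping against a hold-out validation set, with no cross-validation involved. MAX_ITERATIONS and PATIENCE are placeholder constants of mine, and the patience check is one common variation, not necessarily the exact behavior of SimpleEarlyStoppingStrategy:

import org.encog.ml.data.MLDataSet;
import org.encog.ml.train.MLTrain;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;

void trainWithEarlyStopping(BasicNetwork network, MLDataSet training, MLDataSet validation) {
    // Train on the training set only; the validation set is never shown to the trainer.
    MLTrain train = new ResilientPropagation(network, training);

    double bestValidationError = Double.MAX_VALUE;
    int epochsWithoutImprovement = 0;

    for (int epoch = 1; epoch <= MAX_ITERATIONS; epoch++) {
        train.iteration();

        // Measure the error on the held-out data after each epoch.
        double validationError = network.calculateError(validation);

        if (validationError < bestValidationError) {
            bestValidationError = validationError;
            epochsWithoutImprovement = 0;
        } else if (++epochsWithoutImprovement >= PATIENCE) {
            // Stop once the validation error has failed to improve for a few
            // epochs in a row; a single uptick is often just noise.
            break;
        }
    }
    train.finishTraining();
}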
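
And a sketch of the second point: plain k-fold cross-validation using Encog's CrossValidationKFold, kept separate from the early stopping concern. KFOLDS and MAX_ITERATIONS are again placeholders; my understanding is that the error reported here is measured on the held-out folds, so it estimates generalization rather than training fit:

import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.folded.FoldedDataSet;
import org.encog.ml.train.MLTrain;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.training.cross.CrossValidationKFold;
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;

double crossValidate(BasicNetwork network, MLDataSet training) {
    // CrossValidationKFold expects its trainer to be built on a FoldedDataSet.
    FoldedDataSet folded = new FoldedDataSet(training);
    MLTrain train = new ResilientPropagation(network, folded);
    CrossValidationKFold kfold = new CrossValidationKFold(train, KFOLDS);

    // Train for a fixed number of epochs; each iteration cycles through the
    // folds, holding one out at a time.
    for (int epoch = 1; epoch <= MAX_ITERATIONS; epoch++) {
        kfold.iteration();
    }

    double error = kfold.getError();
    kfold.finishTraining();
    return error;
}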