
Train neural network with validation dataset in FANN


As some posts suggest, I started using FANN (http://leenissen.dk/fann/index.php) for my neural network work. It is clean and easy to understand.

However, to avoid over-fitting, I need an algorithm that uses a validation dataset as an auxiliary (see "What is the difference between train, validation and test set, in neural networks?"). Interestingly, the FANN documentation itself recommends that developers consider the over-fitting problem (http://leenissen.dk/fann/wp/help/advanced-usage/).

The thing is, as far as I can see, FANN does not have any function that supports this. The training functions in FANN do not accept a validation dataset as an argument, either. Am I correct? How do FANN users train their neural networks with a validation dataset? Thanks for any help.


Solution

  • You can implement this approach, i.e. a dataset split, with FANN yourself, but you need to train one epoch at a time, using the function fann_train_epoch.

    You start with a big dataset, which you then split for the different steps. The tricky part is that you split the dataset only once, and use only the first part to adjust the weights (the training as such).
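    If you start from a single dataset, the one-time split can be done up front by shuffling the sample indices once and cutting the permutation in two. The sketch below is plain C and independent of FANN; the function name `split_indices`, its parameters and the fraction used in the usage note are illustrative assumptions. If your FANN version provides them, `fann_shuffle_train_data` followed by two calls to `fann_subset_train_data` achieves the same directly on a `struct fann_train_data`.

    ```c
    #include <stdlib.h>

    /* Fill idx[0..n-1] with a random permutation of 0..n-1 (Fisher-Yates shuffle)
     * and return how many rows belong to the training part. Afterwards
     * idx[0..n_train-1] are training rows and idx[n_train..n-1] validation rows.
     * Call this once: the split must stay fixed for the whole training run.
     * Note: reseeds the global rand() generator with `seed`. */
    unsigned int split_indices(unsigned int *idx, unsigned int n,
                               double train_fraction, unsigned int seed)
    {
        unsigned int i, j, tmp;

        for (i = 0; i < n; i++)
            idx[i] = i;

        srand(seed);
        /* Walk from the end, swapping each slot with a random slot at or before it. */
        for (i = n; i > 1; i--) {
            j = (unsigned int)(rand() % (int)i); /* slight modulo bias; fine for a sketch */
            tmp = idx[i - 1];
            idx[i - 1] = idx[j];
            idx[j] = tmp;
        }

        /* Number of rows that go to the training part (truncated). */
        return (unsigned int)(train_fraction * n);
    }
    ```

    For example, with n = 8 and train_fraction = 0.75 the function returns 6: the rows idx[0] to idx[5] go into the training file or arrays, and idx[6] and idx[7] into the validation ones.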

    Say you already have your two datasets, Train and Validation (like in the example you posted). You first need to store them in separate files or arrays. Then you can do the following:

    struct fann *ann;
    struct fann_train_data *dataTrain;
    struct fann_train_data *dataVal;
    

    Assuming that you have both datasets in files:

    dataTrain = fann_read_train_from_file("./train.data");
    dataVal = fann_read_train_from_file("./val.data");
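    fann_read_train_from_file expects FANN's plain-text format: a header line with the number of samples, the number of inputs and the number of outputs, then, for each sample, one line of space-separated inputs followed by one line of space-separated outputs. If your data sits in arrays, a small writer such as the hypothetical `write_fann_file` below can produce such files; for data already loaded into a `struct fann_train_data`, FANN's own `fann_save_train` does this for you.

    ```c
    #include <stdio.h>

    /* Write n samples to `path` in FANN's plain-text training format.
     * inputs is row-major n x num_input, outputs is row-major n x num_output.
     * Returns 0 on success, -1 if the file cannot be opened or closed. */
    int write_fann_file(const char *path, const float *inputs, const float *outputs,
                        unsigned int n, unsigned int num_input, unsigned int num_output)
    {
        unsigned int s, k;
        FILE *f = fopen(path, "w");
        if (f == NULL)
            return -1;

        /* Header: number of samples, inputs per sample, outputs per sample. */
        fprintf(f, "%u %u %u\n", n, num_input, num_output);

        for (s = 0; s < n; s++) {
            /* One line with the inputs of sample s, space-separated. */
            for (k = 0; k < num_input; k++)
                fprintf(f, "%s%g", k ? " " : "", (double)inputs[s * num_input + k]);
            fprintf(f, "\n");
            /* Followed by one line with its desired outputs. */
            for (k = 0; k < num_output; k++)
                fprintf(f, "%s%g", k ? " " : "", (double)outputs[s * num_output + k]);
            fprintf(f, "\n");
        }
        return fclose(f) == 0 ? 0 : -1;
    }
    ```

    Calling it once for the training rows and once for the validation rows gives you the train.data and val.data files read above.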
    

    Then, after setting all network parameters, you train one epoch at a time and, after each epoch, check the error on both datasets. This looks something like:

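    float train_error, val_error;
    float last_val_error = FLT_MAX; /* FLT_MAX from <float.h>: no validation error seen yet */
    unsigned int i;

    for (i = 1; i <= max_epochs; i++) {
        /* One pass over the training data: the only call that changes the weights. */
        fann_train_epoch(ann, dataTrain);
        /* MSE on each dataset, measured without adjusting the weights. */
        train_error = fann_test_data(ann, dataTrain);
        val_error = fann_test_data(ann, dataVal);
        /* Stop as soon as the validation error starts to rise. */
        if (val_error > last_val_error)
            break;
        last_val_error = val_error;
    }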
    

    Of course, this condition is too simple and may stop your training loop too early if the error fluctuates (as it commonly does; see the plot below), but it shows the general idea of how to use different datasets during training.
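    A common way to tolerate such fluctuations is to give the validation error a fixed number of epochs (a "patience") to beat its best value so far, and to stop only when it has not. The sketch below keeps that bookkeeping in a small struct; `struct early_stop`, `early_stop_init`, `early_stop_update` and the patience value are illustrative names, not part of FANN. In the loop above you would call `early_stop_update(&es, i, val_error)` after each epoch instead of comparing with last_val_error, and, when `es.best_epoch == i`, save the network (for example with `fann_save`) so you can go back to the best state.

    ```c
    #include <float.h>

    /* Bookkeeping for early stopping with patience. */
    struct early_stop {
        float best_error;        /* lowest validation error seen so far */
        unsigned int best_epoch; /* epoch at which best_error was reached */
        unsigned int patience;   /* epochs allowed without improvement */
        unsigned int bad_epochs; /* consecutive epochs without improvement */
    };

    void early_stop_init(struct early_stop *es, unsigned int patience)
    {
        es->best_error = FLT_MAX;
        es->best_epoch = 0;
        es->patience = patience;
        es->bad_epochs = 0;
    }

    /* Record the validation error of `epoch`.
     * Returns 1 if training should stop, 0 otherwise. */
    int early_stop_update(struct early_stop *es, unsigned int epoch, float val_error)
    {
        if (val_error < es->best_error) {
            /* New best: remember it and reset the counter. */
            es->best_error = val_error;
            es->best_epoch = epoch;
            es->bad_epochs = 0;
            return 0;
        }
        /* No improvement: stop once the patience is used up. */
        es->bad_epochs++;
        return es->bad_epochs >= es->patience;
    }
    ```

    With validation errors 0.50, 0.40, 0.45, 0.35, 0.36, 0.37 over epochs 1 to 6, the simple rule would stop at epoch 3, while a patience of 2 keeps going, finds the lower error at epoch 4, and stops at epoch 6.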

    By the way, you may want to save these errors so you can plot them against the training epoch and inspect them after training has ended:

    [Plot: training and validation error against training epoch]
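    One simple way to save the errors is to append one CSV line per epoch and plot the file afterwards with any tool (gnuplot, a spreadsheet, matplotlib). The helpers `log_header` and `log_errors` below are illustrative names, not FANN functions; they only rely on standard C stdio.

    ```c
    #include <stdio.h>

    /* Write the CSV header once, before the training loop. */
    void log_header(FILE *csv)
    {
        fprintf(csv, "epoch,train_error,val_error\n");
    }

    /* Append one line per epoch; flushing keeps the file usable
     * even if training is interrupted. */
    void log_errors(FILE *csv, unsigned int epoch, float train_error, float val_error)
    {
        fprintf(csv, "%u,%g,%g\n", epoch, (double)train_error, (double)val_error);
        fflush(csv);
    }
    ```

    In the training loop you would call `log_errors(csv, i, train_error, val_error)` right after the two fann_test_data calls, with `csv` opened via `fopen("errors.csv", "w")` (a hypothetical file name) before the loop.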