Tags: python, neural-network, pytorch, cross-validation

Cross-validation of a neural network: How to treat the number of epochs?


I'm implementing a PyTorch neural network (regression) and want to identify the best network topology, optimizer, etc. I use cross-validation because I have x databases of measurements, and I want to evaluate whether I can train a neural network on a subset of the x databases and apply it to the unseen databases. I therefore also set aside a test database, which I don't use during hyperparameter identification. I am confused about how to treat the number of epochs in cross-validation, e.g., when I train for a maximum of 100 epochs. There are two options:
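A minimal sketch of that data setup in plain Python, assuming a leave-one-database-out scheme (one possible way to "train on a subset and apply to unseen databases"). The helper `database_folds` and the database names are hypothetical. scikit-learn's `LeaveOneGroupOut` performs the same kind of split on arrays that carry a group label per sample.

```python
from typing import Iterator, List, Sequence, Tuple


def database_folds(
    database_ids: Sequence[str], test_id: str
) -> Iterator[Tuple[List[str], str]]:
    """Hold out test_id entirely, then yield (train_ids, val_id) pairs so that
    each remaining database serves exactly once as the unseen validation database.

    Hypothetical helper: leave-one-database-out is an assumed CV scheme.
    """
    if test_id not in database_ids:
        raise ValueError(f"unknown test database: {test_id!r}")
    # The test database never appears in any training or validation fold.
    pool = [d for d in database_ids if d != test_id]
    for val_id in pool:
        yield [d for d in pool if d != val_id], val_id


# Usage with hypothetical database names.
folds = list(database_folds(["db1", "db2", "db3", "db4"], test_id="db4"))
for train_ids, val_id in folds:
    print(train_ids, "->", val_id)
# ['db2', 'db3'] -> db1
# ['db1', 'db3'] -> db2
# ['db1', 'db2'] -> db3
```

Keeping the test database out of `pool` is what guarantees it plays no role in choosing the topology, optimizer, or epoch count.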

  1. The number of epochs is a hyperparameter to tune. After each epoch, the mean validation error across all cross-validation folds is computed. Once models have been trained with all network topologies, optimizers, etc., the model with the smallest mean error is selected, with parameters such as:
    - network topology: 1
    - optimizer: SGD
    - number of epochs: 54
    To calculate the performance on the test set, a model is trained with exactly these parameters (number of epochs = 54) on the combined training and validation data. It is then applied to and evaluated on the test set.

  2. The number of epochs is NOT a hyperparameter to tune. Models are trained with all network topologies, optimizers, etc., and for each model the epoch with the smallest validation error is used. The models are then compared, and the best model is determined, with parameters such as:
    - network topology: 1
    - optimizer: SGD
    To calculate the performance on the test data, a “simple” training/validation split is used (e.g., 80-20). The model is trained with the above parameters for 100 epochs, fitting on the training portion and monitoring the validation portion. Finally, the model from the epoch with the smallest validation error is evaluated on the test data.
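The difference between the two options comes down to how the per-epoch validation curves from the folds are reduced. The sketch below assumes each fold records one validation error per epoch (all folds run for the same number of epochs); the function names and the curves are hypothetical, and scoring option 2 as "mean of each fold's own minimum" is one interpretation of "the epoch where the error is smallest".

```python
from typing import List, Sequence

Curves = Sequence[Sequence[float]]  # one list of per-epoch val errors per fold


def mean_curve(fold_curves: Curves) -> List[float]:
    """Average the per-epoch validation errors across CV folds."""
    n_epochs = len(fold_curves[0])
    if any(len(c) != n_epochs for c in fold_curves):
        raise ValueError("all folds must be trained for the same number of epochs")
    return [sum(c[e] for c in fold_curves) / len(fold_curves) for e in range(n_epochs)]


def option1_epoch(fold_curves: Curves) -> int:
    """Option 1: treat the epoch count as a hyperparameter and return the
    1-based epoch that minimises the mean validation error across folds."""
    curve = mean_curve(fold_curves)
    return min(range(len(curve)), key=curve.__getitem__) + 1


def option2_score(fold_curves: Curves) -> float:
    """Option 2 (assumed scoring): each fold keeps its own best epoch, and the
    configuration is scored by the mean of those per-fold minima."""
    return sum(min(c) for c in fold_curves) / len(fold_curves)


# Hypothetical curves for one topology/optimizer configuration.
fold_curves = [
    [1.0, 0.5, 0.25, 0.5],    # fold 1: best at epoch 3
    [1.0, 0.75, 0.5, 0.125],  # fold 2: best at epoch 4
]
print(mean_curve(fold_curves))     # [1.0, 0.625, 0.375, 0.3125]
print(option1_epoch(fold_curves))  # 4
print(option2_score(fold_curves))  # 0.1875
```

Because each fold's minimum is never larger than its value at any single shared epoch, the option 2 score (0.1875 here) can never exceed the best point of the option 1 mean curve (0.3125). Scores from the two options are therefore not directly comparable with each other, although either can be used consistently to rank configurations.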

Which option is the correct or the better one?


Solution

  • It is better not to tune the number of epochs, so option 2 is the better choice. If the number of epochs were fixed in advance, you would not need a validation set at all; the validation set is what tells you which epoch's model to save.
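Using the validation set to decide which epoch's model to save amounts to "keep the best checkpoint" training. Below is a framework-agnostic sketch with hypothetical callables; in PyTorch, `get_state` would typically be `model.state_dict`, and the saved state could later be restored with `model.load_state_dict(best_state)`. The deep copy is there on the assumption that the returned state may reference live parameters that later epochs keep updating.

```python
import copy
from typing import Any, Callable, Tuple


def train_keep_best(
    train_one_epoch: Callable[[], None],
    validation_loss: Callable[[], float],
    get_state: Callable[[], Any],
    max_epochs: int = 100,
) -> Tuple[int, float, Any]:
    """Train for max_epochs and return (best_epoch, best_loss, best_state).

    best_epoch is 1-based; best_state is a deep copy of the model state taken
    right after the epoch with the lowest validation loss.
    """
    best_epoch, best_loss, best_state = 0, float("inf"), None
    for epoch in range(1, max_epochs + 1):
        train_one_epoch()         # fit on the training split only
        loss = validation_loss()  # monitor on the validation split
        if loss < best_loss:
            best_epoch, best_loss = epoch, loss
            # Snapshot now; without a copy, later epochs would overwrite it.
            best_state = copy.deepcopy(get_state())
    return best_epoch, best_loss, best_state


# Toy usage: a dict stands in for the model, a scripted (hypothetical)
# validation curve stands in for real evaluation.
state = {"w": 0.0}
scripted_losses = iter([0.9, 0.5, 0.3, 0.35, 0.4])


def step() -> None:
    state["w"] += 1.0  # stand-in for one epoch of optimisation


best_epoch, best_loss, best_state = train_keep_best(
    step, lambda: next(scripted_losses), lambda: state, max_epochs=5
)
print(best_epoch, best_loss, best_state)  # 3 0.3 {'w': 3.0}
```

Only the snapshot from epoch 3 would then be evaluated on the test database; the live `state` has moved on to `{'w': 5.0}` by the end of training.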