
Do I need multiple models for cross-validation?


So I have seen differing implementations of cross-validation. I'm currently using PyTorch to train a neural network. My current layout looks like this: I have 6 discrete datasets; 5 are used for cross-validation.

Network_1 trains on Datasets: 1,2,3,4 computes loss on 5
Network_2 trains on Datasets: 1,2,3,5 computes loss on 4
Network_3 trains on Datasets: 1,2,4,5 computes loss on 3
Network_4 trains on Datasets: 1,3,4,5 computes loss on 2
Network_5 trains on Datasets: 2,3,4,5 computes loss on 1

Then comes epoch 2, and I do exactly the same again:

Network_1 trains on Datasets: 1,2,3,4 computes loss on 5
Network_2 trains on Datasets: 1,2,3,5 computes loss on 4
Network_3 trains on Datasets: 1,2,4,5 computes loss on 3
Network_4 trains on Datasets: 1,3,4,5 computes loss on 2
Network_5 trains on Datasets: 2,3,4,5 computes loss on 1

For testing on Dataset 6, I would merge the predictions from all 5 networks and take the average score of the predictions (I still have to implement the averaging of the prediction matrices).
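By averaging I mean something like the following sketch (assuming a classification setup; `models` and `inputs` are placeholders for my five trained networks and a batch from Dataset 6):

    import torch

    # Sketch of the averaging step (names are placeholders):
    # `models` holds the five trained networks, `inputs` is a batch from Dataset 6.
    def ensemble_predict(models, inputs):
        with torch.no_grad():
            # Stack the five prediction matrices and average them element-wise.
            preds = torch.stack([torch.softmax(m(inputs), dim=1) for m in models])
            return preds.mean(dim=0)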

Have I understood cross-validation correctly? Is this how it's supposed to work? Will this work properly? I put effort into not testing with data that I already trained on.

Would greatly appreciate the help :)


Solution

  • You can definitely apply cross-validation with a neural network, but because neural networks are computationally demanding models, this is not usually done. To reduce variance, other techniques are ordinarily applied to neural networks instead, such as early stopping or dropout.
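    For instance, dropout is a one-line addition to a PyTorch model. A minimal sketch (the layer sizes are arbitrary placeholders):

        import torch.nn as nn

        # Minimal sketch: dropout between two fully connected layers.
        model = nn.Sequential(
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Dropout(p=0.5),  # randomly zeroes activations during training
            nn.Linear(128, 10),
        )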

    That being said, I am not sure you're applying it in the right way. You should train across all the epochs, so that:

    Network_1 trains on Datasets: 1,2,3,4 up to the end of training. Then computes loss on 5
    Network_2 trains on Datasets: 1,2,3,5 up to the end of training. Then computes loss on 4
    Network_3 trains on Datasets: 1,2,4,5 up to the end of training. Then computes loss on 3
    Network_4 trains on Datasets: 1,3,4,5 up to the end of training. Then computes loss on 2
    Network_5 trains on Datasets: 2,3,4,5 up to the end of training. Then computes loss on 1
    

    Once each network is trained up to the end of training (so across all the epochs) and validated on the left-out dataset (called the validation dataset), you can average the scores you obtained. A sketch of this procedure follows below.
    This averaged score (and indeed the real point of cross-validation) should give you a fair evaluation of your model, one that should not drop when you go on to test it on the test set (the one you left out from training from the beginning).
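    In PyTorch terms, the corrected procedure looks roughly like this (a sketch, not your exact code: `datasets` is the list of your five Dataset objects, and `make_model`, `train_one_epoch`, `evaluate` and `num_epochs` stand in for your own model factory, training loop, evaluation function and epoch count):

        from torch.utils.data import ConcatDataset, DataLoader

        # Sketch of the corrected procedure; helper names are placeholders.
        fold_scores = []
        for k in range(5):  # one fold per left-out dataset
            train_sets = [d for i, d in enumerate(datasets) if i != k]
            train_loader = DataLoader(ConcatDataset(train_sets), batch_size=32, shuffle=True)
            val_loader = DataLoader(datasets[k], batch_size=32)

            model = make_model()  # fresh network for every fold
            for epoch in range(num_epochs):  # train to the end *before* validating
                train_one_epoch(model, train_loader)
            fold_scores.append(evaluate(model, val_loader))

        cv_score = sum(fold_scores) / len(fold_scores)  # average across folds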

    Cross-validation is usually used together with some form of grid search to produce an unbiased evaluation of the different models you want to compare. So if you want, for example, to compare NetworkA and NetworkB, which differ with respect to some parameters, you run cross-validation for NetworkA, run cross-validation for NetworkB, and then take the one with the highest cross-validation score as the final model.
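    Sketched with a hypothetical `cross_validate` helper that wraps the fold loop above and returns the average score (the model factories are placeholders too):

        # Sketch: pick the architecture with the better cross-validation score.
        score_a = cross_validate(make_network_a, datasets)
        score_b = cross_validate(make_network_b, datasets)
        best_factory = make_network_a if score_a >= score_b else make_network_b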

    As a last step, once you have decided which model is best, you usually retrain it on all the data you have in the train set (i.e. Datasets 1,2,3,4,5 in your case) and test this model on the test set (Dataset 6).
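    Continuing the sketch above (same hypothetical helpers; `dataset_6` stands in for your held-out test set):

        from torch.utils.data import ConcatDataset, DataLoader

        # Sketch of the final step: retrain the chosen model on Datasets 1-5,
        # then report a single score on the held-out Dataset 6.
        full_train_loader = DataLoader(ConcatDataset(datasets), batch_size=32, shuffle=True)
        test_loader = DataLoader(dataset_6, batch_size=32)

        final_model = best_factory()
        for epoch in range(num_epochs):
            train_one_epoch(final_model, full_train_loader)
        test_score = evaluate(final_model, test_loader)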