Tags: python, model, neural-network, training-data

Neural network model initialization with k-fold cross-validation


I initialized the network model before the k-fold loop starts.

Does that mean the model is trained on the first fold, and then that same model, with its trained weights, is carried over to the second fold, and so on? What if the last fold is bad, so the whole model ends up bad?


Solution

  • It depends on what you mean by "initialized the network"; you should show a code snippet so people can understand your problem.

    In principle, k-fold cross-validation is a technique for getting a better estimate of a model's performance. The idea is simple. Without k-fold, you just split the dataset into a train set and a test set, and you use the unseen samples in the test set to estimate the performance/error. But data is rarely perfect; it is usually a bit dirty, so "bad" samples may end up in the test set, and when you use them to estimate the model's performance you get a value that does not represent the real one.
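    A single holdout split can be sketched in plain Python as follows. This is only an illustration under assumptions: `holdout_split` is a hypothetical helper (not a library function), samples are identified by their index, and the split is a random shuffle. The point is that which samples land in the test set depends on the shuffle, so a single estimate can be lucky or unlucky.

```python
import random

def holdout_split(n_samples, test_fraction=0.2, seed=None):
    """Hypothetical helper: shuffle sample indices and split them ONCE into train/test."""
    indices = list(range(n_samples))
    rng = random.Random(seed)
    rng.shuffle(indices)
    n_test = int(round(n_samples * test_fraction))
    test_idx = indices[:n_test]
    train_idx = indices[n_test:]
    return train_idx, test_idx

# Two different seeds generally put different samples in the test set,
# so the single performance estimate depends on which split you happened to get.
train_a, test_a = holdout_split(10, test_fraction=0.3, seed=0)
train_b, test_b = holdout_split(10, test_fraction=0.3, seed=1)
print(sorted(test_a), sorted(test_b))
```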

    To reduce the error in this estimate, you split the dataset into k folds of roughly equal size, then you train a NEW model k times (so the weights are initialized from scratch each time), each time testing on a different one of the k folds and training on the remaining k-1 folds.
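    Below is a minimal, dependency-free sketch of that loop. `kfold_indices`, `TinyLinearModel` and `cross_validate` are hypothetical names, and the tiny linear model trained with gradient descent only stands in for a real neural network. The key line is the one that creates the model inside the loop, so every fold starts from freshly initialized weights. In practice you would normally use a library splitter such as scikit-learn's `KFold` rather than writing your own.

```python
import random

def kfold_indices(n_samples, k, seed=0):
    """Yield (train_idx, test_idx) pairs; each sample lands in exactly one test fold."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    # Spread samples as evenly as possible: the first n % k folds get one extra.
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        test_idx = indices[start:start + size]
        train_idx = indices[:start] + indices[start + size:]
        yield train_idx, test_idx
        start += size


class TinyLinearModel:
    """Hypothetical stand-in for a neural network: y = w * x + b."""

    def __init__(self, seed=None):
        rng = random.Random(seed)
        # Fresh random weights every time a model object is created.
        self.w = rng.uniform(-1.0, 1.0)
        self.b = rng.uniform(-1.0, 1.0)

    def fit(self, xs, ys, lr=0.1, epochs=500):
        n = len(xs)
        for _ in range(epochs):
            # Full-batch gradients of the mean squared error w.r.t. w and b.
            grad_w = sum(2 * (self.w * x + self.b - y) * x for x, y in zip(xs, ys)) / n
            grad_b = sum(2 * (self.w * x + self.b - y) for x, y in zip(xs, ys)) / n
            self.w -= lr * grad_w
            self.b -= lr * grad_b

    def mse(self, xs, ys):
        return sum((self.w * x + self.b - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def cross_validate(xs, ys, k=5, seed=0):
    """Return one test-fold error per fold (k separate estimates)."""
    scores = []
    for fold, (train_idx, test_idx) in enumerate(kfold_indices(len(xs), k, seed)):
        model = TinyLinearModel(seed=fold)  # NEW model per fold: weights start from scratch
        model.fit([xs[i] for i in train_idx], [ys[i] for i in train_idx])
        scores.append(model.mse([xs[i] for i in test_idx], [ys[i] for i in test_idx]))
    return scores


# Usage on noise-free toy data y = 2x + 1 (illustrative only).
xs = [i / 10 for i in range(20)]
ys = [2 * x + 1 for x in xs]
scores = cross_validate(xs, ys, k=5)
print(scores)
```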


    By doing so, you'll have k different estimates of the error/performance of your model.

    If you want a single value as a measure, you just average the k results. Of course, you can use the results however you like: you can SELECT the best model, you can combine the k models (for example by averaging their predictions), you can combine only the "top n" models, and so on.
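    Turning the k per-fold scores into a single figure, plus a spread and the best fold, might look like this. `summarize_folds` is a hypothetical helper and the error values are made up for illustration; it assumes lower scores are better (e.g. an error) unless told otherwise. Reporting the standard deviation next to the mean is a common way to show how much the folds disagree.

```python
import statistics

def summarize_folds(scores, lower_is_better=True):
    """Hypothetical helper: collapse k per-fold scores into mean, spread and best fold."""
    mean = statistics.mean(scores)
    # Sample standard deviation across folds (0.0 if there is only one fold).
    spread = statistics.stdev(scores) if len(scores) > 1 else 0.0
    pick = min if lower_is_better else max
    best_fold = pick(range(len(scores)), key=lambda i: scores[i])
    return mean, spread, best_fold

# Made-up per-fold error values, for illustration only.
fold_errors = [0.20, 0.25, 0.18, 0.40, 0.22]
mean, spread, best = summarize_folds(fold_errors)
print(f"mean={mean:.3f} std={spread:.3f} best fold={best}")
```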

    So, to answer your question: NO, you don't keep the weights. Among the k models you train, one may turn out "bad", but you are using k-fold only to VALIDATE your model, not to train it better! After validation you can decide what to do. You are looking for a measure of how "good" your model is, and k-fold just makes you more confident that this measure is close to the real value.
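    To make the difference concrete, the sketch below contrasts the two setups from the question using a hypothetical `CountingModel` whose only state is how many samples it has been trained on (a stand-in for weights). Creating it once before the loop lets state from earlier folds leak into later ones; creating it inside the loop does not. In a framework such as Keras or PyTorch, the analogue would be building the model inside the loop, since continuing to train the same model object starts from its current weights rather than from a fresh initialization.

```python
class CountingModel:
    """Hypothetical toy model that only records how many samples it has been trained on."""

    def __init__(self):
        self.samples_seen = 0  # stands in for "state carried in the weights"

    def fit(self, train_samples):
        self.samples_seen += len(train_samples)


folds = [[0, 1], [2, 3], [4, 5]]  # 3 hypothetical folds of 2 samples each

def train_on_all_but(i):
    """Training data for fold i: every sample not in fold i."""
    return [s for j, fold in enumerate(folds) if j != i for s in fold]

# Pitfall: model created ONCE before the loop, so its state carries over between folds.
shared = CountingModel()
shared_history = []
for i in range(len(folds)):
    shared.fit(train_on_all_but(i))
    shared_history.append(shared.samples_seen)

# Correct: a NEW model inside the loop, so every fold starts from scratch.
fresh_history = []
for i in range(len(folds)):
    model = CountingModel()
    model.fit(train_on_all_but(i))
    fresh_history.append(model.samples_seen)

print(shared_history)  # [4, 8, 12]: each fold builds on the previous folds' training
print(fresh_history)   # [4, 4, 4]: each fold's model saw only its own training data
```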

    If you want to use the dataset to reduce other kinds of error (such as overfitting), you should look into ensemble methods.
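    One of the simplest ensemble approaches is to average the predictions of several trained models, for example the k models produced during cross-validation. The sketch below assumes each model is just a callable that maps an input to a number; `average_ensemble` and the three lambda "models" are hypothetical. Real ensemble methods (bagging, boosting, stacking) are more involved than this.

```python
def average_ensemble(models, x):
    """Hypothetical helper: average the predictions of several already-trained models for x."""
    predictions = [model(x) for model in models]
    return sum(predictions) / len(predictions)

# Hypothetical fold models represented as plain functions (illustration only).
fold_models = [
    lambda x: 2.0 * x + 1.0,
    lambda x: 2.2 * x + 0.9,
    lambda x: 1.8 * x + 1.1,
]
print(average_ensemble(fold_models, 1.0))
```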

    I hope this helps.