
K-fold cross-validation for a segmentation task in deep learning


I'm new to deep learning and I want to do a semantic segmentation task with U-Net. I heard that one of the strategies to improve my results is to use cross-validation, which is not very popular in deep learning.

I did some research to find out how to implement it on my dataset, but I couldn't find a reliable answer. Can you help me with how I should implement 5-fold cross-validation? Should I train 1 model on 5 folds? That doesn't seem correct to me, because then why shouldn't I just train the model on the whole training dataset? Or should I train 5 models, one per fold, and average the 5 outputs at inference? This might work, but the inference time would be overwhelmingly high. Thank you


Solution

  • K-fold cross-validation (cv) can be used to obtain better insight into how the model will generalize to unseen data.

    To perform 5-fold cv, first split your data into five folds. Set the first fold aside, train a model on the remaining four folds, and evaluate the trained model on the fold that was set aside. Next, set aside the second fold, train a new model on the remaining four folds, and evaluate this model on the second fold. Repeat this process, setting aside each of the remaining folds in turn, until you have created five models, each of which has a single validation score. The mean of these five validation scores is your cross-validation score, which is an estimate of the performance of the model building process (e.g. the fixed preprocessing, hyperparameters, and deep learning algorithm). None of these five models will be your final model. Instead, rerun the model building process (not cross-validation) using all the data for training. The result is the final model, and the estimate of that model's performance is the cross-validation score found previously.
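    The procedure above can be sketched in plain Python. This is a toy, not U-Net training: each hypothetical "image" is a 1-D list of synthetic pixel intensities with a binary mask, `build_model` is a hypothetical stand-in for the whole model building process (here it only learns one intensity threshold), and the score is the mean Dice coefficient on the held-out fold. All names and the synthetic data are illustrative assumptions; the loop structure is the part meant to carry over to real training code.

```python
import random
import statistics


def make_toy_dataset(n_images=20, n_pixels=50, seed=0):
    """Synthetic 1-D 'images' (assumption): foreground pixels are brighter."""
    rng = random.Random(seed)
    data = []
    for _ in range(n_images):
        mask = [rng.random() < 0.3 for _ in range(n_pixels)]
        image = [rng.gauss(0.7 if m else 0.3, 0.15) for m in mask]
        data.append((image, mask))
    return data


def k_fold_indices(n_items, k=5, seed=0):
    """Shuffle the item indices once and deal them into k disjoint folds."""
    idx = list(range(n_items))
    random.Random(seed).shuffle(idx)
    return [idx[i::k] for i in range(k)]


def build_model(train_set):
    """Hypothetical stand-in for training: threshold halfway between class means."""
    fg = [p for img, m in train_set for p, f in zip(img, m) if f]
    bg = [p for img, m in train_set for p, f in zip(img, m) if not f]
    return (statistics.mean(fg) + statistics.mean(bg)) / 2


def dice(pred, target):
    """Dice coefficient of two binary masks (1.0 if both are empty)."""
    inter = sum(p and t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    return 1.0 if total == 0 else 2 * inter / total


def evaluate(threshold, val_set):
    """Mean Dice of the thresholded predictions over the held-out images."""
    return statistics.mean(dice([p > threshold for p in img], m) for img, m in val_set)


def cross_validate(data, k=5):
    folds = k_fold_indices(len(data), k)
    scores = []
    for i, val_idx in enumerate(folds):
        val_set = [data[j] for j in val_idx]
        train_set = [data[j] for f, fold in enumerate(folds) if f != i for j in fold]
        model = build_model(train_set)  # a fresh model for every fold
        scores.append(evaluate(model, val_set))
    return statistics.mean(scores), scores


data = make_toy_dataset()
cv_score, fold_scores = cross_validate(data, k=5)
final_model = build_model(data)  # the deployed model: same process, all the data
print("fold Dice scores:", [round(s, 3) for s in fold_scores])
print(f"cv estimate for the final model: {cv_score:.3f}")
```

    The five fold models are thrown away; only `final_model`, built by the same process on all 20 images, would be deployed, with `cv_score` as its performance estimate.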

    Why perform cross-validation? The idea is that there is randomness in the data and in the model building process. Given such randomness (or noise), when we train a model on a training set and evaluate it on a held-out set, the performance might, just by luck, be better or worse than what we will see once we deploy the model. If we instead look at how our model building process performs across several combinations of training and evaluation data, we get a better indication of how the model will perform.

    Besides estimating the performance of the model that is going to be deployed, another common use of cv is model selection: come up with a series of different model building processes (for instance, different numbers of layers in a neural net) and select the one with the highest cv score. (Note that this cv score is an optimistic indicator of how well the selected model will perform on new data; related terms are "winner's curse", "multiple induction problem", "multiple hypothesis testing problem", and "overhyping".)
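    Model selection by cv can be sketched with the same kind of toy setup, assuming synthetic 1-D "images" again. Here the candidate model building processes differ only in one hypothetical hyperparameter `alpha`, which places the intensity threshold between the foreground and background means; it stands in for a real choice such as the number of layers. Every candidate is scored on the same folds, the highest cv score wins, and the winning process is rerun on all the data. As noted above, `results[best_alpha]` is an optimistic estimate for the selected model.

```python
import random
import statistics


def make_toy_dataset(n_images=20, n_pixels=50, seed=0):
    """Synthetic 1-D 'images' (assumption): foreground pixels are brighter."""
    rng = random.Random(seed)
    data = []
    for _ in range(n_images):
        mask = [rng.random() < 0.3 for _ in range(n_pixels)]
        data.append(([rng.gauss(0.7 if m else 0.3, 0.15) for m in mask], mask))
    return data


def make_process(alpha):
    """Hypothetical model building process: threshold = alpha*fg_mean + (1-alpha)*bg_mean."""
    def build(train_set):
        fg = [p for img, m in train_set for p, f in zip(img, m) if f]
        bg = [p for img, m in train_set for p, f in zip(img, m) if not f]
        return alpha * statistics.mean(fg) + (1 - alpha) * statistics.mean(bg)
    return build


def mean_dice(threshold, val_set):
    """Mean Dice coefficient of thresholded predictions over val_set."""
    scores = []
    for img, mask in val_set:
        pred = [p > threshold for p in img]
        inter = sum(p and t for p, t in zip(pred, mask))
        total = sum(pred) + sum(mask)
        scores.append(1.0 if total == 0 else 2 * inter / total)
    return statistics.mean(scores)


def cv_score(build, data, k=5, seed=0):
    """k-fold cv score of one model building process (fixed seed = same folds)."""
    idx = list(range(len(data)))
    random.Random(seed).shuffle(idx)
    folds = [idx[i::k] for i in range(k)]
    scores = []
    for i, val_idx in enumerate(folds):
        train = [data[j] for f, fold in enumerate(folds) if f != i for j in fold]
        scores.append(mean_dice(build(train), [data[j] for j in val_idx]))
    return statistics.mean(scores)


data = make_toy_dataset()
candidates = {alpha: make_process(alpha) for alpha in (0.2, 0.5, 0.8)}
# Same folds for every candidate, so the comparison is like-for-like.
results = {alpha: cv_score(build, data) for alpha, build in candidates.items()}
best_alpha = max(results, key=results.get)
final_model = candidates[best_alpha](data)  # rerun the winning process on all data
print({a: round(s, 3) for a, s in results.items()}, "-> selected alpha =", best_alpha)
```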

    Cross-validation is not popular in deep learning because it is time-consuming: instead of building one model, you have to build several (k + 1 for k-fold cv, counting the final model trained on all the data). Deep learning is also often used on problems with lots of data, so a train-validation-test split is hoped to be sufficient for model building, model selection, and model validation, respectively.
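    A minimal sketch of such a split, assuming the dataset is identified by a list of hypothetical image IDs and assuming 70/15/15 proportions (a common choice, not a rule). The training set is used to fit the model, the validation set to choose between model building processes, and the test set only once, to estimate the performance of the chosen model.

```python
import random


def train_val_test_split(items, val_frac=0.15, test_frac=0.15, seed=0):
    """Shuffle once, then carve off disjoint test, validation and training subsets."""
    items = list(items)
    random.Random(seed).shuffle(items)
    n_test = round(len(items) * test_frac)
    n_val = round(len(items) * val_frac)
    test = items[:n_test]
    val = items[n_test:n_test + n_val]
    train = items[n_test + n_val:]
    return train, val, test


image_ids = [f"img_{i:04d}" for i in range(1000)]  # hypothetical file IDs
train_ids, val_ids, test_ids = train_val_test_split(image_ids)
print(len(train_ids), len(val_ids), len(test_ids))  # 700 150 150
```

    Compared with the cv sketch, only one model is trained per candidate, which is why this split is preferred when data is plentiful and training is expensive.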