machine-learning, deep-learning, computer-vision, object-detection, supervised-learning

Cross validation in the context of deep learning | Object Detection


I am working on modeling an object detection dataset and I am relatively new to deep learning. I am having a hard time extending the idea of cross-validation to the deep learning setting. Training time is usually huge with deep networks, so k-fold CV is not a reasonable approach; probably 1-fold cross-validation, i.e. a single fixed validation split, makes more sense (I have seen people use this in practice). I am trying to reason about this choice by going back to what cross-validation is for: tuning hyper-parameters and quantifying when the model starts to over-fit. My questions are the following:

  1. What about the random sampling error with a 1-fold CV? My thoughts: with k-fold CV this error is averaged out when k>1. Also, with k=1, hyper-parameter tuning doesn't seem reasonable to me: the values we end up with can be coupled to the (random) sample we happened to call the validation set. So what's the point of a 1-fold CV? (See the sketch after this list for what I mean by the two setups.)

  2. Data is already scarce in my setting: I have around ~4k images, 2 categories (object + background), and bounding boxes for each image. It's common wisdom that deep networks learn better with more data, so why would I want to shrink my training set by holding out a validation set? I don't see any clear advantage; on the contrary, training on the entire dataset seems like it should give a better object detection model. But if that is true, how would I know when to stop, i.e. I could keep training without any feedback on whether the model has started overfitting?

  3. How are production models deployed? I have never thought much about this while taking courses, where the approach was always a fixed train/validation/test split. In real-world settings, how do you leverage the entire dataset to create a production model? (This is probably connected to #2, i.e. practical aspects like how long to train.)
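
To make the two setups in question 1 concrete, here is a rough sketch of what I mean (scikit-learn, with random arrays standing in for my ~4k images; the actual train/score step is elided):

```python
import numpy as np
from sklearn.model_selection import KFold, train_test_split

# Made-up stand-in for my data: 4000 samples, 10 features, 2 classes.
X = np.random.randn(4000, 10)
y = np.random.randint(0, 2, size=4000)

# k-fold CV (k > 1): every sample is used for validation exactly once,
# and the validation score is averaged over the k folds.
for train_idx, val_idx in KFold(n_splits=5, shuffle=True).split(X):
    ...  # train on X[train_idx], evaluate on X[val_idx], then average

# "1-fold" CV: one fixed random split, so any hyper-parameters tuned
# against it are coupled to this particular validation sample.
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2)
```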


Solution

    1. Public computer vision datasets for object detection are usually large enough that this isn't an issue. How much of an issue it is in your scenario is indicated by the gap in performance between the validation and test sets. Cross-validation with k = 1 essentially means having a fixed validation set.
    2. You want to keep the validation set in order to tune the hyper-parameters of your model. Increasing the number of weights will surely increase performance on the training set, but you want to check how the model behaves on unseen data, i.e. the validation set. That said, many people tune hyper-parameters on the validation set and then do one more training run on the combined training and validation data before finally evaluating on the test set.
    3. I think this is already answered in 2. You could extend this by training on all three sets, but whatever performance you achieve on that data will not be representative. The number of epochs/iterations to train for should therefore be decided before merging the data; the sketch below makes this concrete.
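
Here is a minimal sketch of the workflow from points 2 and 3 (PyTorch on synthetic data; the tiny classifier, dataset sizes, and epoch budget are illustrative assumptions, not from the question): tune against a fixed validation set, record the best epoch, retrain from scratch on train + validation for that many epochs, and only then touch the test set once.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, ConcatDataset, random_split

torch.manual_seed(0)

# Stand-in for ~4k images with 2 categories: random features + binary labels.
data = TensorDataset(torch.randn(4000, 64), torch.randint(0, 2, (4000,)))
train_set, val_set, test_set = random_split(data, [3000, 500, 500])

def make_model():
    return nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 2))

def run_epoch(model, loader, optimizer=None):
    """One pass over `loader`; trains if an optimizer is given, else evaluates."""
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    for x, y in loader:
        loss = criterion(model(x), y)
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        total_loss += loss.item() * len(y)
    return total_loss / len(loader.dataset)

# Phase 1: train against a fixed validation set and record the best epoch.
model = make_model()
opt = torch.optim.Adam(model.parameters())
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
val_loader = DataLoader(val_set, batch_size=64)
best_epoch, best_val = 0, float("inf")
for epoch in range(1, 21):
    run_epoch(model, train_loader, opt)
    with torch.no_grad():
        val_loss = run_epoch(model, val_loader)
    if val_loss < best_val:
        best_epoch, best_val = epoch, val_loss

# Phase 2: retrain from scratch on train + validation for `best_epoch` epochs.
final_model = make_model()
opt = torch.optim.Adam(final_model.parameters())
full_loader = DataLoader(ConcatDataset([train_set, val_set]),
                         batch_size=64, shuffle=True)
for _ in range(best_epoch):
    run_epoch(final_model, full_loader, opt)

# One final, untouched evaluation on the test set.
with torch.no_grad():
    test_loss = run_epoch(final_model, DataLoader(test_set, batch_size=64))
print(f"best epoch: {best_epoch}, test loss: {test_loss:.3f}")
```

Retraining from scratch, rather than continuing from the tuned weights, keeps the chosen epoch count meaningful for the larger combined dataset.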

    You have to decide what you want to optimize for. Most papers optimize for performance on the test set, which is why it should never be used for training or for validating hyper-parameter choices. In practice you might often prefer a "better" model obtained by also training on the validation and test data, but you will never know how much better that model actually is until you find another test set, and you risk something "strange" happening when the test data is included. You are essentially training with your eyes closed.