I have a fairly small amount of data, so I decided to try cross-validation in order to get predictions for all of the data, like so:
for train_index, test_index in KFold(9, shuffle=True, random_state=42).split(range(len(df))):
    train_data, test_data = df[train_index], df[test_index]
    train_images, test_images = images[train_index], images[test_index]
    train_labels, test_labels = labels[train_index], labels[test_index]
    model.fit([train_data, train_images], train_labels, epochs=100, batch_size=5)
    model.predict([test_data, test_images])
It is my understanding that doing this will train a new model every time (9 times in total). However, if that is the case, then my loss output makes no sense:
The top grey curve is the first iteration and it starts from nearly 1, then goes down. Subsequent iterations are then all significantly lower.
I'd like to understand what I'm doing wrong here - I want to fully train a new network 9 times and get predictions each time.
When calling model.fit, parameter estimation continues from where it left off, so you're right: once you reach the second iteration of your loop, you are training on top of whatever weights the first iteration produced.
To avoid this, you will want to reset the parameters of your model between iterations. One approach is simply to create the model from scratch in each iteration. The one thing to be aware of is that, by default, weights are initialized with random values (to avoid local extrema of the objective), so re-creating the model entirely gives you a different random starting point each time. To avoid this, and to make sure that every iteration uses the same starting point, you can either fix the initial values (e.g. by fixing the random seed), or just take whatever initialization Keras gives you: call model.save prior to the first iteration and load_model at the start of each iteration, i.e. do what amounts to
from keras.models import load_model

model = ...
model.save('initial.h5')

for ... in KFold:
    model = load_model('initial.h5')
    model.fit(...)
    model.predict(...)
Not that you necessarily have to worry about this: you could also simply choose to view the random initialization as part of the abstract statistical model whose generalization error you are trying to estimate, and just be happy with a new starting point for each iteration.
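The snapshot-and-restore idea is independent of Keras. As a minimal sketch of the pattern (using a hypothetical plain-numpy linear "model" trained by gradient descent and hand-rolled fold indices instead of the real Keras/KFold objects), snapshotting the initial weights once and restoring them at the top of every fold looks like this:

```python
import numpy as np

rng = np.random.default_rng(42)

# Toy data: 90 samples, 4 features, linear targets with a little noise.
X = rng.normal(size=(90, 4))
y = X @ np.array([1.0, -2.0, 0.5, 3.0]) + 0.1 * rng.normal(size=90)

# "Model" parameters, randomly initialized once -- the analogue of
# building the Keras model and calling model.save('initial.h5').
w = rng.normal(size=4)
initial_w = w.copy()  # snapshot of the starting point

n_splits = 9
fold_size = len(X) // n_splits
start_losses = []  # training loss at the start of each fold

for k in range(n_splits):
    # Restore the snapshot -- the analogue of load_model('initial.h5').
    w = initial_w.copy()

    test_idx = np.arange(k * fold_size, (k + 1) * fold_size)
    train_idx = np.setdiff1d(np.arange(len(X)), test_idx)

    # Loss before any training this fold: same starting point every time.
    err = X[train_idx] @ w - y[train_idx]
    start_losses.append(float(np.mean(err ** 2)))

    # A few steps of plain gradient descent on the squared error.
    for step in range(100):
        err = X[train_idx] @ w - y[train_idx]
        w -= 0.01 * X[train_idx].T @ err / len(train_idx)

# Because every fold restores the same snapshot, the loss at the start
# of each fold is comparable across folds rather than shrinking from
# one fold to the next (which is what continuing to train would show).
```

Dropping the `w = initial_w.copy()` line reproduces the behavior in the question: each fold would pick up the weights left over from the previous one, and the starting loss would drop fold after fold.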