I want to train my model with different batch sizes, e.g. [128, 256]. I am doing it with a for loop like this:
batch_sizes = [128, 256]
for i in range(len(batch_sizes)):
    history = model.fit(x_train, y_train, batch_sizes[i], epochs=epochs,
                        callbacks=[early_stopping, chk], validation_data=(x_test, y_test))
With the code above, my model produces the following results:
Epoch 1/2
311/311 [==============================] - 157s 494ms/step - loss: 0.2318 - f1: 0.0723
Epoch 2/2
311/311 [==============================] - 152s 488ms/step - loss: 0.1402 - f1: 0.4360
Epoch 1/2
156/156 [==============================] - 137s 877ms/step - loss: 0.1197 - f1: **0.5450**
Epoch 2/2
156/156 [==============================] - 136s 871ms/step - loss: 0.1132 - f1: 0.5756
It looks like the model continues training from where it left off once the first batch size is finished: the second run already starts with an f1 of 0.5450. I want the model to be trained from scratch for each batch size. How can I do that? Here is what I have tried:
batch_sizes = [128, 256]
for i in range(len(batch_sizes)):
    history = model.fit(x_train, y_train, batch_sizes[i], epochs=epochs,
                        callbacks=[early_stopping, chk], validation_data=(x_test, y_test))
It did not work either.
You can write a function that defines the model, and call it before each subsequent fit call. If your model is held in the variable model, its weights are updated during training and they stay that way after fit returns, so the next fit call resumes from those weights. That is why you need to rebuild the model for each batch size. Something like this should help:
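from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import numpy as np

X = np.random.rand(1000, 5)
Y = np.random.rand(1000, 1)
epochs = 2

def build_model():
    # Returns a new, compiled model with freshly initialized weights.
    # The layers and compile settings are placeholders; use your own.
    model = Sequential()
    model.add(Dense(8, activation='relu', input_shape=(5,)))
    model.add(Dense(1))
    model.compile(optimizer='adam', loss='mse')
    return model

batch_sizes = [128, 256]
for batch_size in batch_sizes:
    model = build_model()  # start from scratch for every batch size
    history = model.fit(X, Y, batch_size=batch_size, epochs=epochs, verbose=2)
    model.save('Model_' + str(batch_size) + '.h5')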
Then, the output looks like:
Epoch 1/2
8/8 - 0s - loss: 0.3164
Epoch 2/2
8/8 - 0s - loss: 0.1367
Epoch 1/2
4/4 - 0s - loss: 0.7221
Epoch 2/2
4/4 - 0s - loss: 0.4787
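Applied to the setup in the question, the same idea might look like the sketch below. It is only an illustration under some assumptions: the data is synthetic, the architecture inside build_model is a placeholder, and the EarlyStopping settings (monitor, patience) are guesses, since the original callbacks are not shown. The callback is created inside the loop so that nothing is shared between runs. If a checkpoint callback such as chk is used, it would probably also be worth creating it per run with a separate file path, so that one run does not overwrite the file of another.

```python
import numpy as np
from tensorflow import keras

# Synthetic stand-ins for x_train, y_train, x_test, y_test (assumption).
rng = np.random.default_rng(0)
x_train, y_train = rng.random((512, 5)), rng.random((512, 1))
x_test, y_test = rng.random((128, 5)), rng.random((128, 1))

def build_model():
    # Placeholder architecture; replace it with the real model definition.
    model = keras.Sequential([
        keras.Input(shape=(5,)),
        keras.layers.Dense(8, activation="relu"),
        keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")
    return model

results = {}
for batch_size in [128, 256]:
    # Release graph/layer state left over from the previous run.
    keras.backend.clear_session()
    # New model, therefore new weights and a new optimizer state.
    model = build_model()
    # New callback per run; monitor and patience are assumed values.
    early_stopping = keras.callbacks.EarlyStopping(monitor="val_loss", patience=2)
    history = model.fit(
        x_train, y_train,
        batch_size=batch_size,
        epochs=2,
        validation_data=(x_test, y_test),
        callbacks=[early_stopping],
        verbose=0,
    )
    # Keep the per-epoch validation loss so the runs can be compared later.
    results[batch_size] = history.history["val_loss"]
```

After the loop, results maps each batch size to its list of per-epoch validation losses, which makes it easy to compare the runs side by side.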