I am building a vanilla DQN model to play the OpenAI Gym CartPole game. However, in the training step, where I feed in the state as input and the target Q-values as the labels, if I use `model.fit(x=states, y=target_q)` it works fine and the agent can eventually play the game well, but if I use `model.train_on_batch(x=states, y=target_q)` the loss won't decrease and the model won't play the game any better than a random policy.
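For reference, the training step looks roughly like this (a simplified, self-contained sketch; the real network and the `states`/`target_q` arrays come from my training loop, so the dummy data here is just for illustration):

```python
import numpy as np
import tensorflow as tf

# Toy stand-in for my Q-network: CartPole has 4 state dims and 2 actions
model = tf.keras.Sequential([
    tf.keras.layers.Dense(24, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(2)
])
model.compile(optimizer='adam', loss='mse')

states = np.random.rand(32, 4).astype(np.float32)    # batch of observations
target_q = np.random.rand(32, 2).astype(np.float32)  # target Q-values

# Variant 1: used in my training loop, the agent eventually learns
model.fit(x=states, y=target_q, verbose=0)

# Variant 2: swapping in this call, the loss won't decrease
model.train_on_batch(x=states, y=target_q)
```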
I wonder what the difference is between `fit` and `train_on_batch`? To my understanding, `fit` just calls `train_on_batch` with a batch size of 32 under the hood, which shouldn't matter here, since setting the batch size equal to the actual size of the data I feed in makes no difference to the result.
The full code is here if more contextual information is needed to answer this question: https://github.com/ultronify/cartpole-tf
`model.fit` trains for one or more epochs, which means it trains on multiple batches. `model.train_on_batch`, as the name implies, trains on exactly one batch.
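To see how close the two can get, note that `fit` with `epochs=1` and `batch_size` set to the full dataset performs a single gradient update, just like `train_on_batch` (a minimal sketch on dummy data; the tiny model here is only an illustration):

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')

x = np.random.rand(10, 4).astype(np.float32)
y = np.random.rand(10, 1).astype(np.float32)

# One epoch, one batch covering all 10 samples -> exactly one gradient update
model.fit(x, y, batch_size=len(x), epochs=1, verbose=0)

# Also exactly one gradient update on the same 10 samples
model.train_on_batch(x, y)
```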
To give a concrete example, imagine you are training a model on 10 images with a batch size of 2. `model.fit` will train on all 10 images, so it will update the gradients 5 times per epoch (and you can specify multiple epochs to keep iterating over the dataset). `model.train_on_batch` performs a single gradient update, because you give the model only one batch: with a batch size of 2, you would pass `model.train_on_batch` just two images.
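If you want to verify the update counts, a Keras callback can count batches during `fit` (a sketch using the standard `on_train_batch_end` hook; the tiny model and random data are just placeholders):

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')

x = np.random.rand(10, 4).astype(np.float32)  # the "10 images"
y = np.random.rand(10, 1).astype(np.float32)

class UpdateCounter(tf.keras.callbacks.Callback):
    """Counts one gradient update per training batch."""
    def __init__(self):
        super().__init__()
        self.updates = 0

    def on_train_batch_end(self, batch, logs=None):
        self.updates += 1

counter = UpdateCounter()
model.fit(x, y, batch_size=2, epochs=1, verbose=0, callbacks=[counter])
print(counter.updates)  # 5: one update per batch of 2
```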
And if we assume that `model.fit` calls `model.train_on_batch` under the hood (though I don't think it does), then `model.train_on_batch` would be called multiple times, likely in a loop. Here's a simplified version in code to explain (assuming `model` is in scope):
```python
def fit(x, y, batch_size, epochs=1):
    for epoch in range(epochs):
        # step through the dataset in chunks of batch_size
        for i in range(0, len(x), batch_size):
            model.train_on_batch(x[i:i + batch_size], y[i:i + batch_size])
```
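With 10 samples and `batch_size=2`, this loop would call `model.train_on_batch` five times per epoch, matching the five gradient updates from the example above.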