Tags: python, tensorflow, batch-size

Better / faster results with a smaller batch size for a linear classifier


I'm currently training multiple linear classifiers with TensorFlow and I found something strange.

If the batch_size is small, my results are better (the model learns faster). I'm working on FashionMNIST.

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.metrics import categorical_accuracy

epochs = 300
batch_size = 5000

# Create and fit model
model = tf.keras.Sequential()
model.add(Dense(1, activation="linear", input_dim=28*28))
model.add(Dense(10, activation="softmax", input_dim=1))
model.compile(optimizer=Adam(), loss=categorical_crossentropy, metrics=[categorical_accuracy])
model.fit(x_train, y_one_hot_train, validation_data=(x_val, y_one_hot_val), epochs=epochs, batch_size=batch_size)

Results

Batch size: 20000, 200 epochs

loss: 2.7494 - categorical_accuracy: 0.2201 - val_loss: 2.8695 - val_categorical_accuracy: 0.2281

Batch size: 10000, 200 epochs

loss: 1.7487 - categorical_accuracy: 0.3336 - val_loss: 1.8268 - val_categorical_accuracy: 0.3331

Batch size: 2000, 200 epochs

loss: 1.2906 - categorical_accuracy: 0.5123 - val_loss: 1.3247 - val_categorical_accuracy: 0.5113

Batch size: 1000, 200 epochs

loss: 1.1080 - categorical_accuracy: 0.5246 - val_loss: 1.1261 - val_categorical_accuracy: 0.5273

Do you know why I get these kinds of results?


Solution

  • Batch size impacts learning significantly. When you put a batch through your network, you average the gradients over that batch. The idea is that if your batch size is large enough, the averaged gradient is a stable enough estimate of the gradient over the full dataset. By sampling from your dataset, you estimate that gradient while reducing the computational cost significantly. The smaller the batch, the less accurate the estimate becomes; however, in some cases these noisy gradients can actually help escape local minima. When the batch size is too low and your data is noisy, the weights can just jump around, and the network may fail to learn or converge very slowly, which negatively impacts total computation time. The sketch below illustrates this noise directly.
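
To see this effect directly, here is a minimal sketch (not from the original post; it assumes the standard FashionMNIST data from tf.keras.datasets and a plain softmax classifier rather than the model in the question) that compares gradients computed on batches of different sizes against the gradient computed on the full training set:

import numpy as np
import tensorflow as tf

# Load and flatten FashionMNIST (assumed preprocessing, similar to the question)
(x_train, y_train), _ = tf.keras.datasets.fashion_mnist.load_data()
x_train = x_train.reshape(-1, 28 * 28).astype("float32") / 255.0
y_train = tf.keras.utils.to_categorical(y_train, 10)

# Hypothetical stand-in model: a single softmax layer, not the exact model above
model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation="softmax", input_shape=(28 * 28,))
])
loss_fn = tf.keras.losses.CategoricalCrossentropy()

def flat_gradient(x, y):
    # Gradient of the loss w.r.t. all trainable weights, flattened into one vector
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    grads = tape.gradient(loss, model.trainable_variables)
    return tf.concat([tf.reshape(g, [-1]) for g in grads], axis=0)

# Treat the gradient over the whole training set as the "true" gradient
full_grad = flat_gradient(x_train, y_train)

for batch_size in (20000, 2000, 200):
    errors = []
    for _ in range(20):
        # Sample a random batch and measure how far its gradient is from the full one
        idx = np.random.choice(len(x_train), batch_size, replace=False)
        g = flat_gradient(x_train[idx], y_train[idx])
        errors.append(tf.norm(g - full_grad).numpy())
    print(f"batch_size={batch_size:6d}  mean ||batch grad - full grad|| = {np.mean(errors):.4f}")

The printed deviation grows as the batch size shrinks, which is exactly the noise described above: small batches give noisier gradient estimates, which can help escape local minima but can also make the weights jump around when pushed too far.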