tensorflow, tensorflow2.0, tensorflow2.x

batch_size in tf model.fit() vs. batch_size in tf.data.Dataset


I have a large dataset that fits in host memory. However, when I train the model with tf.keras, I get a GPU out-of-memory (OOM) error. I then looked into tf.data.Dataset and want to use its batch() method to split the training data into batches so that model.fit() can run on the GPU. The documentation gives the following example:

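```python
import tensorflow as tf

# train_examples/train_labels and test_examples/test_labels are NumPy arrays
# loaded beforehand.
train_dataset = tf.data.Dataset.from_tensor_slices((train_examples, train_labels))
test_dataset = tf.data.Dataset.from_tensor_slices((test_examples, test_labels))

BATCH_SIZE = 64
SHUFFLE_BUFFER_SIZE = 100

# Shuffle, then group consecutive elements into batches of BATCH_SIZE.
train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
test_dataset = test_dataset.batch(BATCH_SIZE)
```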
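The snippet above never defines train_examples or train_labels. A self-contained variant, assuming random NumPy arrays as hypothetical stand-ins (1,000 samples of 28x28 with integer labels; the shapes are made up), shows what batch() produces: every element of the dataset is already a (features, labels) batch whose first dimension is BATCH_SIZE, except possibly the last one.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-ins for train_examples / train_labels:
# 1,000 random 28x28 "images" with integer labels in [0, 10).
rng = np.random.default_rng(0)
train_examples = rng.random((1000, 28, 28), dtype=np.float32)
train_labels = rng.integers(0, 10, size=1000)

BATCH_SIZE = 64
SHUFFLE_BUFFER_SIZE = 100

train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_examples, train_labels))
    .shuffle(SHUFFLE_BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

# Each element of the batched dataset is one (features, labels) batch;
# record how many samples each batch holds.
batch_sizes = [int(features.shape[0]) for features, _ in train_dataset]
print(len(batch_sizes), batch_sizes[0], batch_sizes[-1])  # 16 64 40
```

1,000 samples give 15 full batches of 64 plus a final batch of 40. Passing drop_remainder=True to batch() would discard that last partial batch.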

Is the BATCH_SIZE in dataset.from_tensor_slices().batch() the same as the batch_size argument of the tf.keras model.fit()?

How should I choose BATCH_SIZE so that the GPU gets enough data per step to run efficiently, yet does not run out of memory?


Solution

  • You do not need to pass the batch_size parameter to model.fit() in this case. Keras uses the batch size you set with tf.data.Dataset.batch(), because each element the dataset yields is already a batch.

    As for your other question: the batch size is a hyperparameter that needs careful tuning. If you see OOM errors, decrease it until they go away, typically (but not necessarily) by halving it: 32 -> 16 -> 8, and so on. Batch sizes that are not powers of two also work, so you can decrease in smaller steps as well.

    In your case, I would start with a batch_size of 2 and increase it gradually (3, 4, 5, 6, ...) until you hit an OOM error, then keep the largest value that worked.

    The official model.fit() documentation says so explicitly:

    batch_size : Integer or None. Number of samples per gradient update. If unspecified, batch_size will default to 32. Do not specify the batch_size if your data is in the form of datasets, generators, or keras.utils.Sequence instances (since they generate batches).
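To tie the two parts of the answer together, here is a sketch assuming a toy dataset and a one-layer model (the data, the model, and the helper names BatchCounter, fit_with_batch_size and halve_until_fits are all hypothetical). The batch size goes only into .batch(); model.fit() receives the dataset with no batch_size, and a callback counts how many batches fit() actually processes. halve_until_fits follows the "decrease on OOM" advice by catching tf.errors.ResourceExhaustedError, the exception TensorFlow raises when an allocation fails.

```python
import math

import numpy as np
import tensorflow as tf

# Hypothetical toy data: 200 samples, 8 features, binary labels.
rng = np.random.default_rng(0)
x = rng.random((200, 8), dtype=np.float32)
y = rng.integers(0, 2, size=(200, 1)).astype(np.float32)


class BatchCounter(tf.keras.callbacks.Callback):
    """Counts the training batches (gradient updates) that fit() runs."""

    def __init__(self):
        super().__init__()
        self.batches = 0

    def on_train_batch_end(self, batch, logs=None):
        self.batches += 1


def fit_with_batch_size(batch_size, epochs=1):
    """Batch the data with tf.data and train WITHOUT passing batch_size to fit()."""
    dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size)
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(8,)),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])
    model.compile(optimizer="sgd", loss="binary_crossentropy")
    counter = BatchCounter()
    model.fit(dataset, epochs=epochs, callbacks=[counter], verbose=0)
    return counter.batches


def halve_until_fits(try_batch_size, start=32,
                     oom_errors=(tf.errors.ResourceExhaustedError,)):
    """Try start, start // 2, ... and return the first batch size that trains."""
    batch_size = start
    while batch_size >= 1:
        try:
            try_batch_size(batch_size)
            return batch_size
        except oom_errors:
            batch_size //= 2  # out of memory: halve and retry
    raise RuntimeError("Even batch_size=1 ran out of memory.")


steps = fit_with_batch_size(64)
print(steps, math.ceil(200 / 64))  # 4 4: fit() used the dataset's batch size
print(halve_until_fits(fit_with_batch_size))  # 32: the toy model never OOMs
```

GPU memory is not always released cleanly after an OOM inside one process, so in practice running each attempt in a fresh process may be more reliable. The gradual-increase strategy from the answer can be written the same way, with the loop counting upward and stopping at the first failure.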