Tags: python, tensorflow, forecasting

TensorFlow forecasting: confused about the usage of repeat()


I came across this notebook that covers forecasting. I got it through this article.

I am confused by the second and fourth lines of the code below:

train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_data = train_data.cache().shuffle(buffer_size).batch(batch_size).repeat()

val_data = tf.data.Dataset.from_tensor_slices((x_vali, y_vali))
val_data = val_data.batch(batch_size).repeat()

I understand that we shuffle the data because we don't want to feed it to the model in serial order. From additional reading, I learned that it is better for buffer_size to equal the size of the dataset. But I am not sure what repeat is doing in this case. Could someone explain what is being done here and what the function of repeat is?

I also looked at this page and saw the text below, but it is still not clear to me.

The following methods in tf.Dataset :

repeat( count=0 ) The method repeats the dataset count number of times.
shuffle( buffer_size, seed=None, reshuffle_each_iteration=None) The method shuffles the samples in the dataset. The buffer_size is the number of samples which are randomized and returned as tf.Dataset.
batch(batch_size,drop_remainder=False) Creates batches of the dataset with batch size given as batch_size which is also the length of the batches.
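
To make the three quoted descriptions concrete, here is a minimal sketch on a toy dataset of six integers. The toy data and variable names are assumptions for illustration and do not come from the notebook, and the sketch assumes TensorFlow 2.x with eager execution. One caveat about the quoted text: the current tf.data documentation gives the signature as repeat(count=None), where leaving count unset repeats the dataset indefinitely, rather than count=0.

```python
import tensorflow as tf

# Toy dataset of six integers (hypothetical stand-in for real training data).
ds = tf.data.Dataset.from_tensor_slices(tf.range(6))

# batch(4) groups consecutive elements; the last batch is smaller
# because drop_remainder defaults to False.
batched = [b.numpy().tolist() for b in ds.batch(4)]
# batched == [[0, 1, 2, 3], [4, 5]]

# repeat(2) walks through the dataset twice, then stops.
twice = [x.numpy().tolist() for x in ds.repeat(2)]
# twice == [0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5]

# repeat() with no count never stops, so take(8) bounds how much is read.
endless_head = [x.numpy().tolist() for x in ds.repeat().take(8)]
# endless_head == [0, 1, 2, 3, 4, 5, 0, 1]

# shuffle() changes the order but not the contents. With buffer_size equal
# to the dataset size, every element is in the buffer before any is drawn.
shuffled_sorted = sorted(x.numpy().tolist() for x in ds.shuffle(buffer_size=6, seed=0))
# shuffled_sorted == [0, 1, 2, 3, 4, 5]
```

The order of the calls matters: in the question's pipeline, batch() comes before repeat(), so a short final batch appears at the end of every pass through the data.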

Solution

  • Calling repeat() without passing anything to the count parameter makes the dataset repeat indefinitely.

    In Python terms, a tf.data.Dataset is an iterable. If you have an object ds of type tf.data.Dataset, you can call iter(ds). If the dataset was produced by repeat(), the resulting iterator never runs out of items, i.e., it never raises a StopIteration exception.

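    In the notebook you referenced, the call to tf.keras.Model.fit() passes 100 for the steps_per_epoch parameter. Keras therefore treats every 100 batches as one epoch and pauses training to run validation after each one. The dataset has to keep supplying batches for as long as training runs, and repeat() guarantees that. The repeated val_data needs the same kind of bound: with an endlessly repeating validation dataset, fit() has to be given validation_steps, otherwise a validation pass would never end.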

    TL;DR: leave it in.

    https://github.com/tensorflow/tensorflow/blob/3f878cff5b698b82eea85db2b60d65a2e320850e/tensorflow/python/data/ops/dataset_ops.py#L134-L3445

    https://docs.python.org/3/library/exceptions.html
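
The steps_per_epoch point can be checked by hand with a small sketch. It assumes TensorFlow 2.x with eager execution. The toy sizes and the helper count_steps are hypothetical and only for illustration; steps_per_epoch = 100 is the value mentioned above. The helper pulls batches one at a time, the way a training loop would, and stops when the iterator raises StopIteration.

```python
import math
import tensorflow as tf

num_samples, batch_size = 10, 4   # hypothetical toy sizes
steps_per_epoch = 100             # the value passed to fit() in the notebook

data = tf.data.Dataset.from_tensor_slices(tf.range(num_samples)).batch(batch_size)

# One pass over a finite dataset yields ceil(num_samples / batch_size) batches.
batches_in_one_pass = math.ceil(num_samples / batch_size)


def count_steps(dataset, limit):
    """Pull up to `limit` batches and return how many were actually available."""
    iterator = iter(dataset)
    steps = 0
    try:
        for _ in range(limit):
            next(iterator)
            steps += 1
    except StopIteration:
        # The finite dataset was exhausted before reaching `limit`.
        pass
    return steps


finite_steps = count_steps(data, steps_per_epoch)             # stops after one pass
repeated_steps = count_steps(data.repeat(), steps_per_epoch)  # reaches the limit
```

Here the finite dataset supplies only 3 of the 100 requested steps, while the repeated one supplies all 100. With a real model, a dataset that runs dry mid-epoch would, as far as I know, make Keras stop early with a warning that the input ran out of data, which is the situation repeat() avoids.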