I am following the time series/LSTM tutorial for TensorFlow and struggle to understand what this line does, as it is not really explained:
train_data.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
I tried to look up what the different modules do but I fail to understand the complete command and its effect on the dataset. Here is the entire tutorial: Click
It's an input pipeline definition based on the tf.data (tensorflow.data) API.
Breaking it down:
(train_data # some tf.data.Dataset, likely made of (x, y) tuples
.cache() # caches the elements in memory after the first pass (avoids re-running preprocessing transformations on later passes)
.shuffle(BUFFER_SIZE) # shuffles the samples so the network always receives them in a random order
.batch(BATCH_SIZE) # groups samples into batches of BATCH_SIZE (the last batch of each pass may be smaller)
.repeat()) # repeats forever: the dataset keeps producing batches and never terminates by running out of data (so model.fit needs an explicit steps_per_epoch)
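To make the order of these steps concrete, here is a minimal pure-Python sketch that mimics the cache → batch → repeat part of the chain with plain generators. It only illustrates the semantics and is not how tf.data is implemented; `Cached`, `batch`, `repeat` and `preprocess` are hypothetical helpers, and shuffle is left out so the output is deterministic.

```python
import itertools

calls = 0  # how many times the (hypothetical) preprocessing step has run


def preprocess(x):
    """Hypothetical stand-in for an expensive preprocessing step."""
    global calls
    calls += 1
    return x * 10


class Cached:
    """Re-iterable wrapper mimicking .cache(): the first complete pass pulls
    from upstream and stores the elements; later passes replay the stored list."""

    def __init__(self, make_upstream):
        self._make_upstream = make_upstream  # callable returning a fresh iterator
        self._store = None

    def __iter__(self):
        if self._store is not None:
            yield from self._store  # cache hit: no upstream work at all
            return
        items = []
        for item in self._make_upstream():
            items.append(item)
            yield item
        self._store = items  # stored only once a full pass has completed


def batch(source, size):
    """Group elements into lists of `size`; the last list may be smaller."""
    buf = []
    for item in source:
        buf.append(item)
        if len(buf) == size:
            yield buf
            buf = []
    if buf:
        yield buf


def repeat(make_epoch):
    """Restart the (re-creatable) pass forever, like .repeat() with no count."""
    while True:
        yield from make_epoch()


cached = Cached(lambda: (preprocess(x) for x in range(5)))
pipeline = repeat(lambda: batch(cached, 2))  # defined, but nothing has run yet

first_seven = list(itertools.islice(pipeline, 7))
print(first_seven)  # [[0, 10], [20, 30], [40], [0, 10], [20, 30], [40], [0, 10]]
print(calls)        # 5: preprocessing ran only during the first pass
```

Note how the smaller batch `[40]` shows up at the end of every pass, and how `repeat` simply starts the next pass, so the stream never ends on its own. In the real pipeline `.shuffle()` sits after `.cache()`; as far as I know, the `cache()` documentation recommends exactly this order, because a cache replays the same elements every pass and shuffling after it is what keeps each epoch's order different.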
Notes:
- cache(): the first pass over the dataset runs the preceding pipeline steps and stores the results in memory; the second and later passes load data from that cache instead of re-running those steps. This saves you some time if the data preprocessing is complex (but, for big datasets, this may be very heavy on your memory).
- BUFFER_SIZE is the number of items in the shuffle buffer. The function fills the buffer and then randomly samples from it, refilling it with new items as it goes. A big enough buffer is needed for proper shuffling, but it's a balance with memory consumption. Reshuffling happens automatically at every epoch (reshuffle_each_iteration defaults to True).

Pay attention: this is a pipeline definition, so you're specifying which operations are in the pipeline, not actually running them! The operations actually happen when elements are requested, e.g. when you call next(iter(dataset)) or when model.fit iterates over the dataset, not before.
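The shuffle buffer and the lazy evaluation described in these notes can both be seen in a small pure-Python sketch. It follows the fill-then-sample idea described above, but it is only an illustration: `buffered_shuffle` and `source` are hypothetical names, and tf.data's internals may differ in details such as exactly when input is pulled.

```python
import random

log = []  # records which source elements have actually been produced


def source(n):
    """Hypothetical upstream dataset that logs every element it produces."""
    for x in range(n):
        log.append(x)
        yield x


def buffered_shuffle(elements, buffer_size, rng):
    """Mimics the buffer idea behind .shuffle(buffer_size): fill the buffer,
    then emit a random buffered element and refill its slot from the input."""
    buf = []
    for item in elements:
        if len(buf) < buffer_size:
            buf.append(item)  # still filling the buffer
            continue
        i = rng.randrange(buffer_size)
        yield buf[i]          # emit a random element from the buffer...
        buf[i] = item         # ...and put the new element in its slot
    rng.shuffle(buf)          # input exhausted: drain the rest in random order
    yield from buf


stream = buffered_shuffle(source(10), buffer_size=3, rng=random.Random(0))
print(log)  # []: defining the pipeline produced nothing yet

first = next(iter(stream))
print(log)  # [0, 1, 2, 3]: just enough input to fill the buffer, plus one
print(first in (0, 1, 2))  # True: a 3-item buffer only sees the first 3 items

# buffer_size=1 degenerates to no shuffling at all
print(list(buffered_shuffle(range(10), 1, random.Random(0))))
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

A buffer of 1 leaves the order unchanged, while a buffer at least as large as the dataset gives a full shuffle; anything in between is the memory/quality trade-off mentioned in the note on BUFFER_SIZE.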