I have multiple TFRecord files, each of which holds a specific timeframe of my data. The data points are consecutive within each file, but not consecutive across files. As part of my input pipeline, I am using tf.contrib.data.sliding_window_batch to process a window of data points, as follows:
import os

import tensorflow as tf

filenames = [os.path.join(data_dir, f) for f in os.listdir(data_dir)]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(parser_fn, num_parallel_calls=6)
dataset = dataset.map(preprocessing_fn, num_parallel_calls=6)
dataset = dataset.apply(tf.contrib.data.sliding_window_batch(window_size=y + z)) # sliding window
dataset = dataset.map(lambda x: prepare_fn(x, y, z))
dataset = dataset.shuffle(buffer_size=100000)
dataset = dataset.batch(32)
dataset = dataset.repeat()
dataset = dataset.prefetch(2)
How can I prevent a window from spanning data points that come from different files?
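One idea is to filter such windows out after the fact: tag every element with the index of the file it came from, slide the window as before, and drop windows that mix indices. This is only a rough sketch, reusing the parser_fn, preprocessing_fn, y and z names from above and assuming they yield a single tensor per record:

def tagged(idx, filename):
    d = tf.data.TFRecordDataset(filename)
    d = d.map(parser_fn, num_parallel_calls=6)
    d = d.map(preprocessing_fn, num_parallel_calls=6)
    # attach the (constant) file index to every element of this file's dataset
    return d.map(lambda x: (tf.constant(idx, dtype=tf.int64), x))

dataset = tagged(0, filenames[0])
for i, f in enumerate(filenames[1:], start=1):
    dataset = dataset.concatenate(tagged(i, f))

dataset = dataset.apply(tf.contrib.data.sliding_window_batch(window_size=y + z))
# keep only windows whose elements all come from the same file
dataset = dataset.filter(lambda idx, x: tf.reduce_all(tf.equal(idx, idx[0])))
dataset = dataset.map(lambda idx, x: x)  # drop the file index again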
An alternative would be to create the batches on each file independently, and interleave the results:
def interleave_fn(filename):
    # build the per-file dataset and its sliding window inside the function,
    # so no window can span two files
    dataset = tf.data.TFRecordDataset(filename)
    dataset = dataset.map(parser_fn, num_parallel_calls=6)
    dataset = dataset.map(preprocessing_fn, num_parallel_calls=6)
    dataset = dataset.apply(tf.contrib.data.sliding_window_batch(window_size=y + z))  # sliding window
    return dataset
filenames = [os.path.join(data_dir, f) for f in os.listdir(data_dir)]
dataset = tf.data.Dataset.from_tensor_slices(filenames)
dataset = dataset.interleave(interleave_fn, num_parallel_calls=...)
dataset = dataset.map(lambda x: prepare_fn(x, y, z))
dataset = dataset.shuffle(buffer_size=1000000)
dataset = dataset.batch(32)
dataset = dataset.repeat()
dataset = dataset.prefetch(2)
This is probably more performant, since it avoids having to filter out the cross-file windows afterwards.
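A quick way to convince myself that the interleaved pipeline never produces cross-file windows is to run the same structure on fake data. This is a minimal sketch, assuming TF 1.x (where tf.contrib.data is available); the integer ranges and cycle_length are made-up values standing in for real files:

import tensorflow as tf

window = 3  # stands in for y + z

def fake_file(start):
    # each "file" is a run of 100 consecutive integers starting at `start`
    d = tf.data.Dataset.range(start, start + 100)
    return d.apply(tf.contrib.data.sliding_window_batch(window_size=window))

starts = tf.data.Dataset.from_tensor_slices(
    tf.constant([0, 1000, 2000], dtype=tf.int64))  # three fake "files"
dataset = starts.interleave(fake_file, cycle_length=3)

next_window = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    for _ in range(6):
        print(sess.run(next_window))  # every window stays inside one run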