
tensorflow - Input pipeline with multiple TFRecord files + tf.contrib.data.sliding_window_batch()


I have multiple TFRecord files, each holding a specific timeframe of my data. The data points are consecutive within each file, but not across files. As part of my input pipeline, I am using tf.contrib.data.sliding_window_batch to process a window of data points, as follows:

import os
import tensorflow as tf

filenames = [os.path.join(data_dir, f) for f in os.listdir(data_dir)]
dataset = tf.data.TFRecordDataset(filenames)  # reads all files as one concatenated stream

dataset = dataset.map(parser_fn, num_parallel_calls=6)
dataset = dataset.map(preprocessing_fn, num_parallel_calls=6)
dataset = dataset.apply(tf.contrib.data.sliding_window_batch(window_size=y + z)) # sliding window
dataset = dataset.map(lambda x: prepare_fn(x, y, z))
dataset = dataset.shuffle(buffer_size=100000)
dataset = dataset.batch(32)
dataset = dataset.repeat()
dataset = dataset.prefetch(2)
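
For illustration, here is a toy stand-in for the pipeline (TF 1.x; integer ranges in place of parsed TFRecords, window_size=3 chosen arbitrarily). Because tf.data.TFRecordDataset(filenames) reads the files as one concatenated stream, the sliding window crosses the file boundary:

import tensorflow as tf

# Two "files" of four consecutive points each, concatenated the way
# TFRecordDataset concatenates a list of filenames.
file_a = tf.data.Dataset.range(0, 4)
file_b = tf.data.Dataset.range(10, 14)
dataset = file_a.concatenate(file_b)
dataset = dataset.apply(tf.contrib.data.sliding_window_batch(window_size=3))

next_window = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  try:
    while True:
      print(sess.run(next_window))
  except tf.errors.OutOfRangeError:
    pass
# Prints [0 1 2], [1 2 3], [2 3 10], [3 10 11], [10 11 12], [11 12 13];
# the middle two windows mix points from the two files.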

How can I prevent a window from spanning data points from different files?


Solution

  • An alternative to filtering out the spanning windows after the fact is to create the windows on each file independently, and interleave the results:

    def interleave_fn(filename):
      # Build the windowed dataset per file, so no window can span two files.
      dataset = tf.data.TFRecordDataset(filename)
      dataset = dataset.map(parser_fn, num_parallel_calls=6)
      dataset = dataset.map(preprocessing_fn, num_parallel_calls=6)
      dataset = dataset.apply(tf.contrib.data.sliding_window_batch(window_size=y + z))  # per-file sliding window
      return dataset
    
    filenames = [os.path.join(data_dir, f) for f in os.listdir(data_dir)]
    dataset = tf.data.Dataset.from_tensor_slices(filenames)
    dataset = dataset.interleave(interleave_fn, cycle_length=..., num_parallel_calls=...)  # cycle_length = number of files interleaved at a time
    dataset = dataset.map(lambda x: prepare_fn(x, y, z))
    dataset = dataset.shuffle(buffer_size=1000000)
    dataset = dataset.batch(32)
    dataset = dataset.repeat()
    dataset = dataset.prefetch(2)
    

    This is probably more performant than windowing over the concatenated files and then filtering out the windows that span a file boundary, since the cross-file windows are never created in the first place.
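
    As a quick sanity check, here is a self-contained toy version of the same pattern (TF 1.x; integer ranges stand in for the parsed per-file datasets, and window_size=3 is arbitrary):

    import tensorflow as tf

    def make_file_dataset(file_id):
      # Stand-in for TFRecordDataset + parsing: "file" k holds 10*k .. 10*k + 4.
      return tf.data.Dataset.range(5).map(lambda x: x + file_id * 10)

    def interleave_fn(file_id):
      ds = make_file_dataset(file_id)
      # The window is built per file, so it can never span two files.
      return ds.apply(tf.contrib.data.sliding_window_batch(window_size=3))

    dataset = tf.data.Dataset.range(3).interleave(interleave_fn, cycle_length=3)

    next_window = dataset.make_one_shot_iterator().get_next()
    with tf.Session() as sess:
      try:
        while True:
          print(sess.run(next_window))
      except tf.errors.OutOfRangeError:
        pass
    # Prints [0 1 2], [10 11 12], [20 21 22], [1 2 3], [11 12 13], ...;
    # every window stays within a single "file".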