Tags: python, tensorflow, keras, tensorflow2.0, tensorflow-datasets

TensorFlow dataset: how to concatenate/repeat data within each batch?


Suppose I have the following dataset: dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])

With batch_size=2, I get [[1,2], [3,4], [5,6]].

However, I would like to get the following output: [[1,2,1,2], [3,4,3,4], [5,6,5,6]]

Basically, I want to repeat each batch twice along the batch dimension and use the result as the new batch. This is a toy example; in a real case, given a batch of shape (64, 300), I would like to produce a batch of shape (128, 300).


Solution

  • You can do this by batching first and then mapping a function over the dataset that concatenates each batch with itself along axis 0:

    import tensorflow as tf

    def double_input(x):
      # Concatenate the batch with itself along axis 0 (the batch axis).
      return tf.concat([x, x], axis=0)

    dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])
    dataset = dataset.batch(2)           # batches: [1 2], [3 4], [5 6]
    dataset = dataset.map(double_input)  # applied to each batch, not each element

    for x in dataset:
      print(x)

    # Output:
    # tf.Tensor([1 2 1 2], shape=(4,), dtype=int32)
    # tf.Tensor([3 4 3 4], shape=(4,), dtype=int32)
    # tf.Tensor([5 6 5 6], shape=(4,), dtype=int32)
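
The same approach should carry over to the (64, 300) case mentioned in the question. Below is a minimal sketch under assumed data: the `features` tensor (256 random rows of 300 values) and the name `double_batch` are made up for illustration and are not from the original post.

```python
import tensorflow as tf

def double_batch(x):
    # Concatenate the batch with itself along axis 0 (the batch axis),
    # so a (64, 300) batch becomes (128, 300).
    return tf.concat([x, x], axis=0)

# Hypothetical data: 256 examples with 300 features each.
features = tf.random.uniform((256, 300))

dataset = tf.data.Dataset.from_tensor_slices(features)
dataset = dataset.batch(64)          # each batch has shape (64, 300)
dataset = dataset.map(double_batch)  # each batch now has shape (128, 300)

for batch in dataset.take(1):
    print(batch.shape)  # (128, 300)
```

Because `map` is applied after `batch`, the function receives a whole batch rather than a single element. For a 2-D batch, `tf.tile(x, [2, 1])` would give the same result as the `tf.concat` call; which one to use is mostly a matter of taste.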