Tags: python, tensorflow, tensorflow-datasets

Tensorflow batch resize dimension


After batching, each element of my dataset is a pair of tensors with the following shape:

tf.Tensor(
[[[  2 436 381 ... 416 333   3]]

 [[  2 651 374 ... 654 370   3]]

 [[  2 743 357 ... 771 358   3]]

 ...

 [[  2 594 432 ... 552 425   3]]

 [[  2 820 409 ... 886 438   3]]

 [[  2 734 397 ... 825 330   3]]], shape=(64, 1, 34), dtype=int64) tf.Tensor(
[[[  2 335 395 ... 281 405   3]]

 [[  2 542 379 ... 512 370   3]]

 [[  2 676 356 ... 696 354   3]]

 ...

 [[  2 733 411 ... 718 403   3]]

 [[  2 828 389 ... 883 407   3]]

 [[  2 774 376 ... 850 316   3]]], shape=(64, 1, 34), dtype=int64)

However, I want each batch to have shape (64, 34). I tried reshaping after batching, but it did not work. This is how the batches are created:

import tensorflow as tf

BATCH_SIZE = 64
# BUFFER_SIZE and train_examples are defined elsewhere (not shown).

def prepare(ds):
  # Split each batch along axis 1 into a source half and a target half.
  srcs, trgs = tf.split(ds, num_or_size_splits=2, axis=1)
  return srcs, trgs

def make_batches(ds):
  return (
      ds
      .cache()
      .shuffle(BUFFER_SIZE)
      .batch(BATCH_SIZE, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      .map(prepare, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      .prefetch(tf.data.experimental.AUTOTUNE))

train_batches = make_batches(train_examples)

Solution

  • Change your prepare method to this:

    def prepare(x):
      srcs, trgs = tf.split(x, num_or_size_splits = 2, axis=1)
      return tf.squeeze(srcs, axis=1), tf.squeeze(trgs, axis=1)
    

    tf.split leaves a size-1 axis in each half, and tf.squeeze removes it, so each batch should come out with shape (64, 34).
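
    To check the shapes without the real data, here is a minimal sketch. The `examples` dataset is a hypothetical stand-in (all zeros, 128 elements of shape (2, 34)), not the asker's `train_examples`, and `prepare_unstack` is an illustrative alternative that is not part of the original answer: `tf.unstack` splits along an axis and drops that axis in one step, which should give the same shapes here.

```python
import tensorflow as tf

# Hypothetical stand-in data: 128 examples, each a (source, target) pair
# of 34 token ids, so every element has shape (2, 34).
examples = tf.data.Dataset.from_tensor_slices(
    tf.zeros([128, 2, 34], dtype=tf.int64))

def prepare(x):
    # x: (batch, 2, 34) -> two tensors of shape (batch, 1, 34)
    srcs, trgs = tf.split(x, num_or_size_splits=2, axis=1)
    # Drop the size-1 axis left behind by the split.
    return tf.squeeze(srcs, axis=1), tf.squeeze(trgs, axis=1)

def prepare_unstack(x):
    # Alternative (assumption, not from the answer): unstack along axis 1,
    # which splits and removes that axis in a single call.
    srcs, trgs = tf.unstack(x, axis=1)
    return srcs, trgs

# Take the first batch from each variant of the pipeline.
srcs, trgs = next(iter(examples.batch(64).map(prepare)))
alt_srcs, alt_trgs = next(iter(examples.batch(64).map(prepare_unstack)))

print(srcs.shape, trgs.shape)
print(alt_srcs.shape, alt_trgs.shape)
```

    Passing axis=1 to tf.squeeze explicitly is the safer choice: a bare tf.squeeze(srcs) would also remove the batch axis if a batch ever contained a single example.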