python tensorflow dataset double

How can I split output from the tf.data.Dataset?


I have some .npy files extracted from MRI images; each array has a shape in the range [150~180, 480~512, 480~512].

I used some functions to refine the data and convert it to a tf.data.Dataset:

train_dataset = tf.data.Dataset.from_tensor_slices((list_image,list_label))
train_dataset = train_dataset.shuffle(NUM_TRN)
train_dataset = train_dataset.batch(NUM_BATCH_NPY)
train_dataset = train_dataset.map(
    lambda x,y: tf.py_function(load_dataset, inp=[x,y], Tout=[tf.float16, tf.float16]))


Description:
1) list_image and list_label are lists of .npy file names
---- [000_image.npy, ..., 099_image.npy], [000_label.npy, ..., 099_label.npy]

2) NUM_TRN is the total number of files; it is used as the shuffle buffer size so that the whole dataset is shuffled
---- 100 (the number of *_image.npy files)

3) NUM_BATCH_NPY is the number of .npy files that are loaded at once
---- If NUM_BATCH_NPY is 3, three image/label pairs are loaded together
---- [000_image/label.npy], [001_image/label.npy], [002_image/label.npy]

4) The function 'load_dataset' loads the arrays from the .npy files,
   refines them, and stacks them along axis 0.
---- 000_image.npy->(170,360,360,1), 000_label.npy->(170,)
---- 001_image.npy->(150,360,360,1), 001_label.npy->(150,)
---- 002_image.npy->(163,360,360,1), 002_label.npy->(163,)
---- the output shape of each dataset element will be ((483,360,360,1),(483,))
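
For reference, the loading and stacking in step 4 might look roughly like the sketch below. This is not the real 'load_dataset': the name stack_npy_files is hypothetical, the refinement step (whatever turns the raw arrays into 360x360x1 slices) is left out, and it assumes plain Python path strings. Inside tf.py_function the paths arrive as string tensors and would need converting first.

```python
import numpy as np


def stack_npy_files(image_paths, label_paths):
    """Load each .npy file and concatenate the arrays along axis 0.

    Hypothetical stand-in for 'load_dataset'; no refinement is applied.
    """
    images = [np.load(path) for path in image_paths]
    labels = [np.load(path) for path in label_paths]
    # Slice counts differ per file, so concatenate instead of np.stack.
    return np.concatenate(images, axis=0), np.concatenate(labels, axis=0)
```

With files holding 170, 150, and 163 slices, the first axis of both outputs has length 483, matching the shapes listed above.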

As mentioned above, the arrays are extracted per image. The question is: how can I split each extracted dataset element into slices of NUM_TRAIN_BATCH=128?

---- extracted dataset (483,360,360,1)->(128,:),(128,:), ...


Solution

  • I would extract batches and create new datasets:

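    datasets = []
    # Group the dataset elements into batches of NUM_TRAIN_BATCH
    train_dataset = train_dataset.batch(NUM_TRAIN_BATCH)
    for batch in train_dataset:
      # Wrap each batch in its own dataset ("batch" avoids shadowing the built-in set)
      batch_dataset = tf.data.Dataset.from_tensor_slices(batch)
      datasets.append(batch_dataset)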
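
  • An alternative that may be closer to what the question asks is to flatten the stacked elements back into single slices with unbatch() and then regroup them with batch(). The sketch below is only an illustration: fake_groups is a hypothetical stand-in for the real pipeline, and the spatial size is shrunk to 4x4 to keep it light. Note that outputs of tf.py_function carry no static shape, so in the real pipeline the rank would probably have to be restored first (for example with tf.ensure_shape in a map step) before unbatch() can be applied.

```python
import numpy as np
import tensorflow as tf

NUM_TRAIN_BATCH = 128  # batch size from the question


def fake_groups():
    # Hypothetical stand-in for the question's pipeline: each element is one
    # stacked group of slices with a different leading dimension.
    for n in (483, 300):
        yield np.zeros((n, 4, 4, 1), np.float16), np.zeros((n,), np.float16)


train_dataset = tf.data.Dataset.from_generator(
    fake_groups,
    output_signature=(
        tf.TensorSpec(shape=(None, 4, 4, 1), dtype=tf.float16),
        tf.TensorSpec(shape=(None,), dtype=tf.float16),
    ),
)

# unbatch() splits every element along axis 0 into single slices;
# batch() then regroups the slices into batches of NUM_TRAIN_BATCH.
rebatched = train_dataset.unbatch().batch(NUM_TRAIN_BATCH)

batch_sizes = [int(images.shape[0]) for images, labels in rebatched]
print(batch_sizes)
```

    Here the 483 + 300 = 783 slices come out as six batches of 128 followed by one batch of 15. Passing drop_remainder=True to batch() would discard that last partial batch.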