tensorflow, tensorflow2.0, tensorflow-datasets

How to batch with tf.data.Dataset.from_generator? Do I need to modify the generator?


I'm using batch(8). It changes the shape and adds a batch dimension, but I only get one image per batch. Below is my code:

import cv2
import numpy as np
import os
import tensorflow as tf
import random

folder_path = "./real/"
files = os.listdir(folder_path)

def get_image():
    index = random.randint(0,len(files)-1)
    img = cv2.imread(folder_path+files[index])
    img = cv2.resize(img,(128,128))
    img = img/255.
    #More complex transformation
    yield img

dset = tf.data.Dataset.from_generator(get_image,(tf.float32)).batch(8)

for img in dset:
    print(img.shape)
    break

The output is still (1, 128, 128, 3) even after using batch(8). Do I need to modify the generator to create the batch manually? Also, how can the generator be wrapped in TensorFlow so that it runs faster?


Solution

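  • This happens because your generator yields a single image and then returns, so the dataset contains only one element and batch(8) can only emit a partial batch of size 1. Yield inside a loop instead. Here's an example: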

    def get_image():
        for file in files:
            img = cv2.imread(folder_path + file)  # 3-channel (BGR) image by default
            img = cv2.resize(img, (128, 128))
            img = img / 255.
    
            yield img  # Yield inside the loop: one image per iteration
    
    # The images have 3 channels, so the element shape is (128, 128, 3)
    dataset = tf.data.Dataset.from_generator(get_image, output_types=tf.float32, output_shapes=(128, 128, 3))
    
    next(iter(dataset.batch(8))).shape
    
    # TensorShape([8, 128, 128, 3])
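
On the speed part of the question, one possible approach (a minimal sketch, not benchmarked) is to keep the generator as thin as possible and move per-element work into the tf.data pipeline, where map() can run in parallel and prefetch() can overlap loading with training. The example assumes TensorFlow 2.4 or later (for output_signature and tf.data.AUTOTUNE) and uses random NumPy arrays as a stand-in for the image files. The names make_fake_images, image_generator and build_dataset are hypothetical and not from the question. shuffle() takes the place of picking a random index inside the generator.

```python
import numpy as np
import tensorflow as tf

IMG_SHAPE = (128, 128, 3)  # assumption: 128x128 colour images, as in the question


def make_fake_images(n, seed=0):
    """Hypothetical stand-in for files on disk: n random uint8 images."""
    rng = np.random.default_rng(seed)
    return rng.integers(0, 256, size=(n, *IMG_SHAPE), dtype=np.uint8)


def image_generator(images):
    """Yield one image per iteration; batching is left to tf.data."""
    for img in images:
        yield img


def build_dataset(images, batch_size=8):
    ds = tf.data.Dataset.from_generator(
        lambda: image_generator(images),
        output_signature=tf.TensorSpec(shape=IMG_SHAPE, dtype=tf.uint8),
    )
    # Scale to [0, 1] inside the pipeline so the work can be parallelised.
    ds = ds.map(
        lambda x: tf.cast(x, tf.float32) / 255.0,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    # Shuffle single images, then group them into batches and prefetch.
    return ds.shuffle(64).batch(batch_size).prefetch(tf.data.AUTOTUNE)


images = make_fake_images(20)
dataset = build_dataset(images)

batch_shapes = [batch.numpy().shape for batch in dataset]
print(batch_shapes)
# [(8, 128, 128, 3), (8, 128, 128, 3), (4, 128, 128, 3)]
```

The last batch is smaller because 20 is not a multiple of 8. Pass drop_remainder=True to batch() if every batch must have the same size.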