Keras ImageDataGenerator: how to get the batch of labels as a separate input?


My model requires two separate inputs: images and labels. But with ImageDataGenerator.flow_from_dataframe I can only get full batches containing both images and labels together. What should I do?


Solution

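  • The issue is that flow_from_dataframe only accepts a single dataframe column (x_col) as x, so the iterator always yields (images, labels) batches. You can wrap the iterator in tf.data.Dataset.from_generator and then use tf.data.Dataset.map to duplicate the labels so they are also passed to the model as a second input. Here is an example using flow_from_directory (the iterator returned by flow_from_dataframe yields batches with the same structure):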

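    import tensorflow as tf
    
    BATCH_SIZE = 32
    
    flowers = tf.keras.utils.get_file(
        'flower_photos',
        'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
        untar=True)
    
    img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20)
    
    # Wrap the Keras iterator in a tf.data pipeline. output_signature replaces the
    # deprecated output_types argument; shapes assume the default target_size
    # (256, 256) and one-hot ("categorical") labels for the 5 flower classes.
    ds = tf.data.Dataset.from_generator(
        lambda: img_gen.flow_from_directory(flowers, batch_size=BATCH_SIZE, shuffle=True),
        output_signature=(
            tf.TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32),
            tf.TensorSpec(shape=(None, 5), dtype=tf.float32)))
    
    # Inputs become (images, labels); the target stays labels.
    ds = ds.map(lambda x, y: ((x, y), y))
    
    for x, y in ds.take(1):
        input1, input2 = x
        print(input1.shape, input2.shape)
    
    Output:
    
    Found 3670 images belonging to 5 classes.
    (32, 256, 256, 3) (32, 5)
    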
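    If you would rather not go through tf.data at all, a plain Python generator can do the same restructuring. This is a minimal sketch, assuming the wrapped iterator (for example the one returned by img_gen.flow_from_dataframe(...) or flow_from_directory(...)) yields (images, labels) batches and that your model takes images as its first input and labels as its second. with_labels_as_input and fake_flow are hypothetical names used only for illustration; fake_flow stands in for the Keras iterator with random NumPy data.

```python
import numpy as np


def with_labels_as_input(batches):
    """Turn (images, labels) batches into ((images, labels), labels).

    `batches` is any iterable of (x, y) pairs, e.g. a Keras
    DirectoryIterator or DataFrameIterator (those repeat forever).
    """
    for images, labels in batches:
        # First element: the model's two inputs; second element: the targets.
        yield (images, labels), labels


def fake_flow(num_batches=3, batch_size=32, num_classes=5):
    """Hypothetical stand-in for img_gen.flow_from_dataframe(...)."""
    rng = np.random.default_rng(0)
    for _ in range(num_batches):
        images = rng.random((batch_size, 256, 256, 3), dtype=np.float32)
        class_ids = rng.integers(0, num_classes, size=batch_size)
        labels = np.eye(num_classes, dtype=np.float32)[class_ids]  # one-hot rows
        yield images, labels


for (input1, input2), target in with_labels_as_input(fake_flow(num_batches=1)):
    print(input1.shape, input2.shape, target.shape)  # (32, 256, 256, 3) (32, 5) (32, 5)
```

    With a real Keras iterator, something like model.fit(with_labels_as_input(train_iter), steps_per_epoch=len(train_iter), epochs=...) should work in TF 2.x: the Keras iterators never stop on their own, so steps_per_epoch tells fit where each epoch ends.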

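    Alternatively, if you don't need ImageDataGenerator's augmentation, tf.keras.utils.image_dataset_from_directory already returns a tf.data.Dataset, so only the map step is needed. Note that its labels are integer-encoded by default (label_mode='int'), so input2 here has shape (batch_size,); pass label_mode='categorical' if you want one-hot labels as in the first example: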

    import tensorflow as tf
    import pathlib
    
    dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
    data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
    data_dir = pathlib.Path(data_dir)
    
    batch_size = 32
    
    ds = tf.keras.utils.image_dataset_from_directory(
      data_dir,
      validation_split=0.2,
      subset="training",
      seed=123,
      image_size=(180, 180),
      batch_size=batch_size)
    
    ds = ds.map(lambda x, y: ((x, y), y))
    
    for x, y in ds.take(1):
      input1, input2 = x
      print(input1.shape, input2.shape)