Tags: python, tensorflow, tensorflow2.0, tf.keras, tensorflow-datasets

create tensorflow dataset from list_files


I am trying to create a TensorFlow dataset:

import tensorflow as tf

path_imgs = ('./images/train/*.jpg')
path_masks = ('./masks/train/*.jpg')

images = tf.data.Dataset.list_files(path_imgs, shuffle=False)
masks = tf.data.Dataset.list_files(path_masks, shuffle=False)

dataset = tf.data.Dataset.from_tensor_slices((tf.constant(path_imgs),
                                              tf.constant(path_masks)))

and I am receiving this error:

Unbatching a tensor is only supported for rank >= 1


Solution

  • The error occurs because tf.constant(path_imgs) is a scalar (rank-0) string tensor, and from_tensor_slices can only slice along the first axis of a tensor with rank >= 1. Build the datasets from list_files and zip them instead. Try something like this:

    import tensorflow as tf
    
    path_imgs = ('/content/images/*.jpg')
    path_masks = ('/content/masks/*.jpg')
    
    images = tf.data.Dataset.list_files(path_imgs, shuffle=False)
    masks = tf.data.Dataset.list_files(path_masks, shuffle=False)
    
    ds = tf.data.Dataset.zip((images, masks))
    
    def load_data(image_path, mask_path):
      image = tf.image.decode_image(tf.io.read_file(image_path))
      mask = tf.image.decode_image(tf.io.read_file(mask_path))
      return image, mask
    
    ds = ds.map(load_data)
    
    for x, y in ds:
      print(x.shape, y.shape)
    
    (100, 100, 3) (100, 100, 3)
    (100, 100, 3) (100, 100, 3)
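
    One caveat: tf.image.decode_image returns a tensor with no static shape (and expands animated GIFs to 4-D), which can trip up later ops such as resizing. Since the files here are JPEGs, a minimal sketch using decode_jpeg instead, with an assumed target size of 100x100, could look like this:

    IMG_SIZE = (100, 100)  # assumed target size, adjust to your data

    def load_jpeg_pair(image_path, mask_path):
      # decode_jpeg yields a rank-3 tensor with a known channel count,
      # so downstream ops like resize work without set_shape
      image = tf.image.decode_jpeg(tf.io.read_file(image_path), channels=3)
      mask = tf.image.decode_jpeg(tf.io.read_file(mask_path), channels=3)
      return tf.image.resize(image, IMG_SIZE), tf.image.resize(mask, IMG_SIZE)

    ds = tf.data.Dataset.zip((images, masks)).map(load_jpeg_pair)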
    

    Note, however, what the docs state regarding tf.data.Dataset.list_files:

    The file_pattern argument should be a small number of glob patterns. If your filenames have already been globbed, use Dataset.from_tensor_slices(filenames) instead, as re-globbing every filename with list_files may result in poor performance with remote storage systems.
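
    If your filenames are already globbed, a minimal sketch of that from_tensor_slices approach (reusing the same paths and load_data as above) could be:

    import glob

    img_files = sorted(glob.glob('/content/images/*.jpg'))
    mask_files = sorted(glob.glob('/content/masks/*.jpg'))

    # from_tensor_slices preserves order; sorting keeps image/mask
    # pairs aligned by filename
    ds = tf.data.Dataset.from_tensor_slices((img_files, mask_files))
    ds = ds.map(load_data)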

    Splitting also works:

    train_ds, test_ds = tf.keras.utils.split_dataset(ds, left_size=0.5, right_size=0.5, shuffle=True, seed=123)
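
    Note that split_dataset is only available in newer TensorFlow releases (added around TF 2.10, if I recall correctly). A quick sanity check of the resulting split sizes:

    # each split is itself a tf.data.Dataset
    print(train_ds.cardinality().numpy(), test_ds.cardinality().numpy())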
    
