python tensorflow keras tensorflow-datasets tf.keras

How to attach or get filenames from MapDataset from image_dataset_from_directory() in Keras?


I am training a convolutional autoencoder and I have this code for loading the data (images):

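import tensorflow as tf
from tensorflow.keras import layers

# image_size is defined elsewhere in my code
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    'path/to/images',
    image_size=image_size
)
normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)

def adjust_inputs(images, labels):
    # The autoencoder reconstructs its input, so the image is both X and Y
    normalized = normalization_layer(images)
    return normalized, normalized

normalized_train_ds = train_ds.map(adjust_inputs)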

As I don't need the class labels but the images themselves as Y, I map the function adjust_inputs over the dataset. But now when I try to access the attribute filenames, I get the error: AttributeError: 'MapDataset' object has no attribute 'filenames'. That makes sense, because map() returns a new MapDataset object rather than the original dataset.

How can I attach or get the filenames of the loaded images in my Dataset?

I am really surprised that there is no easier interface for this; it seems like a pretty common need.


Solution

  • If you want to add the file paths as part of your dataset:

    import tensorflow as tf
    import pathlib
    import matplotlib.pyplot as plt
    
    dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
    data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
    data_dir = pathlib.Path(data_dir)
    
    batch_size = 32
    # shuffle=False keeps the batches in the same order as train_ds.file_paths
    train_ds = tf.keras.utils.image_dataset_from_directory(data_dir, shuffle=False, batch_size=batch_size)
    
    # Batch the paths with the same batch size so they line up with the image batches
    paths_ds = tf.data.Dataset.from_tensor_slices(train_ds.file_paths).batch(batch_size)
    
    normalization_layer = tf.keras.layers.Rescaling(1./255)
    def change_inputs(batch, paths):
      images, _ = batch  # the class labels are not needed for an autoencoder
      x = normalization_layer(images)
      return x, x, paths
    
    # Each zipped element is ((images, labels), paths)
    normalized_ds = tf.data.Dataset.zip((train_ds, paths_ds)).map(change_inputs)
    
    inputs, targets, paths = next(iter(normalized_ds.take(1)))
    
    image = inputs[0]
    path = paths[0]
    print(path)
    plt.imshow(image.numpy())
    
    Found 3670 files belonging to 5 classes.
    tf.Tensor(b'/root/.keras/datasets/flower_photos/daisy/100080576_f52e8ee070_n.jpg', shape=(), dtype=string)
    <matplotlib.image.AxesImage at 0x7f9b113d1a10>
    

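    You just have to make sure that you use the same batch size for the paths. The pairing also relies on shuffle=False, as in the code above, so that the batches come out in the same order as file_paths.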
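
To illustrate what "the same batch size" means here, the following is a minimal plain-Python sketch (no TensorFlow involved) of how an unshuffled list of paths lines up with the batches. The helper name paths_for_batch and the sample paths are made up for illustration, and the sketch assumes the dataset was built with shuffle=False, so that batch i holds the files at positions i * batch_size up to (i + 1) * batch_size.

```python
from typing import List, Sequence


def paths_for_batch(file_paths: Sequence[str], batch_index: int, batch_size: int) -> List[str]:
    """Return the paths belonging to batch `batch_index` of an unshuffled dataset."""
    start = batch_index * batch_size
    # Slicing past the end is safe, so the final batch may simply be shorter
    return list(file_paths[start:start + batch_size])


# Hypothetical paths standing in for train_ds.file_paths
file_paths = [f"flowers/img_{i}.jpg" for i in range(10)]
batch_size = 4

first = paths_for_batch(file_paths, 0, batch_size)
last = paths_for_batch(file_paths, 2, batch_size)
print(first)
print(last)
```

With the real dataset, train_ds.file_paths would take the place of the sample list. If the paths were sliced with a different batch size than the images, the two would drift apart from the second batch onward.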