
How to use train_data.take(1) with TensorFlow


I'm working on a test with TensorFlow. My dataset is stored in two folders. I configured batch_size, img_height and img_width for train_data, but I can't display the images with matplotlib or use the dataset in a model.

#Import dataset
import pathlib
import os

import matplotlib.pyplot as plt
import tensorflow as tf

data_dir = pathlib.Path(r'C:\Users\vion1\Ele\Engie\Exercices\DL\Pikachu\dataset')
image_count = len(list(data_dir.glob('*/*')))
print(image_count)
#374

batch_size = 32
img_height = 256
img_width = 256

train_data = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=42,
  image_size=(img_height, img_width),
  batch_size=batch_size,
  )

class_names = train_data.class_names
print(train_data)
#Found 374 files belonging to 2 classes.
#Using 300 files for training.
#<BatchDataset shapes: ((None, 256, 256, 3), (None,)), types: (tf.float32, tf.int32)>

plt.figure(figsize=(10, 10))
for images, labels in train_data.take(1):
  for i in range(3):
    ax = plt.subplot(1, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.axis("off")

The error is:

InvalidArgumentError: Unknown image file format. One of JPEG, PNG, GIF, BMP required.
     [[{{node decode_image/DecodeImage}}]] [Op:IteratorGetNext]

I think that train_data.take(1) doesn't read the files, but I don't understand why or how to fix it. Any ideas?


Solution

  • The code you posted looks correct. As the error message says, the most likely cause is that one or more files in your dataset are not in one of the supported formats (JPEG, PNG, GIF, BMP). The failure only appears at train_data.take(1) because the files are not read and decoded until the dataset is iterated. To look for the offending files, you can use the code below, shown here on the example dataset used in the TensorFlow documentation.

    import matplotlib.pyplot as plt
    import numpy as np
    import os
    import PIL
    import tensorflow as tf
    
    from tensorflow import keras
    
    import pathlib
    dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
    data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
    data_dir = pathlib.Path(data_dir)
    
    roses = list(data_dir.glob('roses/*'))
    

    Now, let's check the unique file extensions in the roses directory.

    file_names = [str(i) for i in roses]
    unique_files = set(i.split('.')[-1] for i in file_names)
    print(unique_files)
    
    Output:
    {'jpg'}
    

    If the output contains any extension other than the supported file types, you need to recheck your data. Otherwise, you can follow this document for the same procedure.
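
An extension check can miss files whose name ends in .jpg but whose content is in another format, since the decoder looks at the file content rather than the name. As a rough sketch (not part of the original answer), the leading bytes of each file can be compared against the well-known signatures of the four supported formats. The helper names `sniff_format` and `find_suspect_files` are hypothetical, and the check assumes the same `<data_dir>/<class>/<file>` layout as in the question. It is only a heuristic: a truncated file with a valid header would still pass it.

```python
from pathlib import Path

# Leading "magic" bytes of the four formats named in the error message.
SIGNATURES = {
    "JPEG": (b"\xff\xd8\xff",),
    "PNG": (b"\x89PNG\r\n\x1a\n",),
    "GIF": (b"GIF87a", b"GIF89a"),
    "BMP": (b"BM",),
}


def sniff_format(path):
    """Return 'JPEG', 'PNG', 'GIF' or 'BMP' based on the file header, else None."""
    with open(path, "rb") as f:
        header = f.read(8)
    for name, prefixes in SIGNATURES.items():
        if header.startswith(prefixes):
            return name
    return None


def find_suspect_files(data_dir):
    """List files under data_dir/<class>/ whose header matches none of the formats."""
    return sorted(
        p for p in Path(data_dir).glob("*/*")
        if p.is_file() and sniff_format(p) is None
    )
```

Calling `find_suspect_files(data_dir)` on the dataset folder and printing the result should point to the files worth inspecting, converting, or removing before the dataset is built again.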