Search code examples
tensorflowcomputer-visiontfrecord

Reading back a custom dataset TFRecords


I am trying to create a custom dataset in TFRecords for a CycleGAN model. The model requires a new type of dataset which is not available, so I need to create one. I have a few JPG images of size 256x256. Following this link, I created TFRecords files for my images with the code below:

def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature`` backed by an ``Int64List``."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)


def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature`` backed by a ``BytesList``."""
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)

# images input
def convert_to(images, output_directory, name):
    """Write the entries of ``images`` to ``<output_directory>/<name>.tfrecords``.

    One ``tf.train.Example`` is written per index along the first axis of
    ``images``; each example stores ``height``, ``width``, ``depth`` and the
    entry's bytes under ``image_raw``.

    Args:
        images: Indexed as a stack: axis 0 enumerates examples, axes 1 and 2
            are read as rows and columns. Presumably a NumPy array -- confirm.
        output_directory: Directory that receives the file.
        name: File name without the ``.tfrecords`` extension.
    """
    num_examples = images.shape[0]
    rows = images.shape[1]
    cols = images.shape[2]
    # Hard-coded: the stored depth does not depend on the input.
    depth = 1

    filename = os.path.join(output_directory, name + '.tfrecords')
    print('Writing', filename)
    # The writer is not closed anywhere in this function.
    writer = tf.python_io.TFRecordWriter(filename)
    for index in range(num_examples):
        # `image` is assigned but not used below.
        image = images[index]
        # NOTE(review): .tobytes() presumably yields the raw array buffer, not
        # JPEG/PNG-encoded data -- confirm against how the reader decodes it.
        image_raw = images[index].tobytes()
        example = tf.train.Example(features=tf.train.Features(feature={
            'height': _int64_feature(rows),
            'width': _int64_feature(cols),
            'depth': _int64_feature(depth),
            'image_raw': _bytes_feature(image_raw)}))
        writer.write(example.SerializeToString())

def read_image(file_name, images_path):
    """Load and return the image at ``images_path + file_name`` via scikit-image."""
    full_path = images_path + file_name
    return skimage.io.imread(full_path)

def get_name(img_name):
    """Return the part of ``img_name`` before its first "." and first "_"."""
    # Drop everything from the first dot onwards.
    stem, _, _ = img_name.partition(".")
    # Keep only the leading underscore-separated token.
    prefix, _, _ = stem.partition("_")
    return prefix

# Input directory; the trailing slash matters because read_image concatenates
# this prefix directly with the file name.
images_path = "data/train/"
image_list = os.listdir(images_path)
images = []  # not used in the lines shown here
for img_name in tqdm(image_list):
    # Output name: the part of the file name before the first "." and "_".
    tfrec_name = get_name(img_name)
    print(tfrec_name)
    img_data = read_image(img_name, images_path)
    # The result of a single read_image call is passed; one .tfrecords file
    # is written per input file.
    convert_to(img_data, "data/cat_image_tfrecords", tfrec_name+"_cat_image")

Once the TFRecords are written, I use the code below to read and decode them:

# Every .tfrecords file found in the output directory of the writing step.
PHOTO_FILENAMES = tf.io.gfile.glob(str('data/cat_image_tfrecords/*.tfrecords'))

# The two spatial dimensions decode_image reshapes each image to.
IMAGE_SIZE = [256, 256]

def decode_image(image):
    """Decode JPEG bytes into a float32 tensor in [-1, 1] shaped ``IMAGE_SIZE + [3]``."""
    pixels = tf.image.decode_jpeg(image, channels=3)
    # Map the 0..255 range onto -1..1.
    scaled = tf.cast(pixels, tf.float32) / 127.5 - 1
    return tf.reshape(scaled, [*IMAGE_SIZE, 3])

def read_tfrecord(example):
    """Parse one serialized ``tf.train.Example`` and return its decoded image."""
    # Three scalar int64 fields followed by the scalar string holding the image bytes.
    feature_spec = {
        key: tf.io.FixedLenFeature([], tf.int64)
        for key in ('height', 'width', 'depth')
    }
    feature_spec['image_raw'] = tf.io.FixedLenFeature([], tf.string)
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image_raw'])

def load_dataset(filenames, labeled=True, ordered=False):
    """Build a dataset of decoded images from the given TFRecord files.

    ``labeled`` and ``ordered`` are accepted but not used by this function.
    """
    records = tf.data.TFRecordDataset(filenames)
    return records.map(read_tfrecord, num_parallel_calls=AUTOTUNE)

# Batches of one image each.
photo_ds = load_dataset(PHOTO_FILENAMES, labeled=False).batch(1)
# Pull the first batch from the dataset.
example_photo = next(iter(photo_ds))

The decoding does not work: the last line raises the error below.

InvalidArgumentError: Expected image (JPEG, PNG, or GIF), got unknown format starting with ']LBhXKeQFVC4S=/1'
     [[{{node DecodeJpeg}}]]

Clearly there is a mismatch between how I am writing the TFRecord in convert_to function and how I am reading it back in read_tfrecord function. But I am not sure how to fix it. Any suggestion?

EDIT

@sebastian-sz solution solves the problem. I tried to display one image like below

import matplotlib.pyplot as plt
plt.subplot(121)
plt.title('Photo')
# NOTE(review): decode_image scales pixels to [-1, 1]; imshow presumably clips
# float RGB data to [0, 1], which would explain a darker picture. Plotting
# example_photo[0] * 0.5 + 0.5 should restore the brightness -- confirm.
plt.imshow(example_photo[0])

It displays the image but I see that the color/light of the image is much darker than the original image. Not sure what is going on though. Attached screenshot. Original image at the bottom.

enter image description here


Solution

  • There are a few issues in your code:

    Parameter Issue:

    The problem is in the function convert_to. More specifically, the function expects a batch (list) of images and indexes into it:

    (...)
    image = images[index]
    (...)
    

    However, you are passing a single image
    convert_to(img_data, "data/cat_image_tfrecords", tfrec_name+"_cat_image")
    As a result, images[index] is a single row of the image, with a shape of (for example) (224, 3), which is not a valid image shape.

    To fix this, change convert_to to accept a single image.

    Serialization Issue

    The array's .tobytes() returns the raw pixel buffer, not an encoded image, so tf.image.decode_jpeg cannot parse it. Use tf.io.encode_jpeg(image).numpy() to obtain JPEG-encoded bytes instead.

    Full code

    I was able to save and read a sample image with the following code:

    # Saving
    def _int64_feature(value):
        """Wrap a single integer in a ``tf.train.Feature`` backed by an ``Int64List``."""
        int_list = tf.train.Int64List(value=[value])
        return tf.train.Feature(int64_list=int_list)
    
    
    def _bytes_feature(value):
        """Wrap a single bytes value in a ``tf.train.Feature`` backed by a ``BytesList``."""
        byte_list = tf.train.BytesList(value=[value])
        return tf.train.Feature(bytes_list=byte_list)
    
    # images input
    def convert_to(image, output_directory, name):
        """Write a single image to ``<output_directory>/<name>.tfrecords``.

        The pixels are stored JPEG-encoded under ``image_raw`` so the record
        can be read back with ``tf.image.decode_jpeg``; ``height``, ``width``
        and ``depth`` hold the dimensions of the encoded array.

        Args:
            image: Array of shape (height, width) or (height, width, channels).
                Presumably uint8, as ``tf.io.encode_jpeg`` expects -- confirm
                for images that do not come from ``skimage.io.imread``.
            output_directory: Directory that receives the file.
            name: File name without the ``.tfrecords`` extension.
        """
        filename = os.path.join(output_directory, name + '.tfrecords')
        print('Writing', filename)
        print(image.shape)

        # encode_jpeg needs an explicit channel axis, so promote a 2-D
        # grayscale array to (height, width, 1).
        if image.ndim == 2:
            image = image[..., None]
        # Record the real channel count instead of a hard-coded 1.
        rows, cols, depth = image.shape

        # Encode before opening the writer so that a failed encode does not
        # leave an empty .tfrecords file behind.
        image_raw = tf.io.encode_jpeg(image).numpy()
        example = tf.train.Example(features=tf.train.Features(feature={
            'height': _int64_feature(rows),
            'width': _int64_feature(cols),
            'depth': _int64_feature(depth),
            'image_raw': _bytes_feature(image_raw)}))

        # The context manager flushes and closes the file even if write() raises.
        with tf.compat.v1.python_io.TFRecordWriter(filename) as writer:
            writer.write(example.SerializeToString())
    
    def read_image(file_name, images_path):
        """Load and return the image ``file_name`` located in ``images_path``.

        Args:
            file_name: Name of the image file, relative to ``images_path``.
            images_path: Directory containing the image, with or without a
                trailing path separator.

        Returns:
            Whatever ``skimage.io.imread`` returns for that file.
        """
        import os

        # os.path.join inserts the separator only when it is missing, so
        # "data/train" and "data/train/" both resolve to the same file.
        return skimage.io.imread(os.path.join(images_path, file_name))
    
    def get_name(img_name):
        """Return the part of ``img_name`` before its first "." and first "_"."""
        # Drop everything from the first dot onwards.
        stem, _, _ = img_name.partition(".")
        # Keep only the leading underscore-separated token.
        prefix, _, _ = stem.partition("_")
        return prefix
    
    images_path = "data/train/"
    output_directory = "data/cat_image_tfrecords"
    # Create the output directory up front; presumably TFRecordWriter does not
    # create missing parent directories, so the first write would otherwise
    # fail on a fresh checkout.
    os.makedirs(output_directory, exist_ok=True)
    image_list = os.listdir(images_path)
    for img_name in tqdm(image_list):
        # Output name: the part of the file name before the first "." and "_".
        tfrec_name = get_name(img_name)
        print(tfrec_name)
        img_data = read_image(img_name, images_path)
        # One .tfrecords file is written per input image.
        convert_to(img_data, output_directory, tfrec_name + "_cat_image")
    
    
    # Loading:
    # Every .tfrecords file found in the output directory of the saving step.
    PHOTO_FILENAMES = tf.io.gfile.glob(str('data/cat_image_tfrecords/*.tfrecords'))
    
    # Target size passed to tf.image.resize in decode_image.
    IMAGE_SIZE = [256, 256]
    
    def decode_image(image):
        """Decode JPEG bytes, scale them to [-1, 1] as float32 and resize to ``IMAGE_SIZE``."""
        pixels = tf.image.decode_jpeg(image, channels=3)
        # Map the 0..255 range onto -1..1.
        scaled = tf.cast(pixels, tf.float32) / 127.5 - 1
        # Resize replaces the earlier reshape; a reshape is enough when all
        # images already share the same shape.
        return tf.image.resize(scaled, IMAGE_SIZE)
    
    def read_tfrecord(example):
        """Parse one serialized ``tf.train.Example`` and return its decoded image."""
        # Three scalar int64 fields followed by the scalar string holding the image bytes.
        feature_spec = {
            key: tf.io.FixedLenFeature([], tf.int64)
            for key in ('height', 'width', 'depth')
        }
        feature_spec['image_raw'] = tf.io.FixedLenFeature([], tf.string)
        parsed = tf.io.parse_single_example(example, feature_spec)
        return decode_image(parsed['image_raw'])
    
    def load_dataset(filenames, labeled=True, ordered=False):
        """Build a dataset of decoded images from the given TFRecord files.

        ``labeled`` and ``ordered`` are accepted but not used by this function.
        """
        records = tf.data.TFRecordDataset(filenames)
        return records.map(read_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
    
    # Batches of one image each.
    photo_ds = load_dataset(PHOTO_FILENAMES, labeled=False).batch(1)
    # Pull the first batch from the dataset.
    example_photo = next(iter(photo_ds))