python · tensorflow2.0 · tfrecord

Resizing flattened images loaded from TFRecord files


Is it really necessary to store image dimension information in TFRecord files? I'm currently working with a dataset composed of images of different scales, and I did not store the width, height, and number of channels for the images I handled. Now I'm facing a problem resizing them back to their original shape after loading the TFRecords, in order to run other preprocessing steps such as data augmentation.

# Create dataset
records_path = DATA_DIR + 'TFRecords/train_0.tfrecords'
dataset = tf.data.TFRecordDataset(filenames=records_path)

# Parse dataset
parsed_dataset = dataset.map(parsing_fn)

# Get iterator
iterator = tf.compat.v1.data.make_one_shot_iterator(parsed_dataset) 
image,label = iterator.get_next()

# Get the numpy array from tensor, convert to uint8 and plot image from array
img_array = image.numpy()
img_array = img_array.astype(np.uint8)
plt.imshow(img_array)
plt.show()

Output: TypeError: Invalid dimensions for image data

Before converting to uint8, was I supposed to resize the image back to its original shape? If so, how can I do that if I did not store the dimension information?

The pipeline below demonstrates one example of a transformation I wanted to apply to the image read from the TFRecord, but I believe these Keras augmentation methods require a properly resized array with defined dimensions to operate. (I don't necessarily need to print the images.)

def brightness(brightness_range, image_path):
    # Load the image from disk and convert the PIL image to a NumPy array.
    img = tf.keras.preprocessing.image.load_img(image_path)
    data = tf.keras.preprocessing.image.img_to_array(img)
    # Add a batch dimension: (H, W, C) -> (1, H, W, C).
    samples = np.expand_dims(data, 0)
    print(samples.shape)
    datagen = tf.keras.preprocessing.image.ImageDataGenerator(brightness_range=brightness_range)
    iterator = datagen.flow(samples, batch_size=1)
    for i in range(9):
        plt.subplot(330 + 1 + i)
        batch = iterator.next()
        image = batch[0].astype('uint8')
        plt.imshow(image)
    plt.show()

brightness([0.2,1.0],DATA_DIR+"183350/5c3e30f1706244e9f199d5a0c5a5ec00d1cbf473.jpg")
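As an alternative to the Keras generator, brightness jitter can be applied directly to tensors inside the tf.data pipeline, assuming the image already has a defined (H, W, C) shape. A sketch using tf.image (the function name is illustrative):

```python
import tensorflow as tf

def random_brightness(image, max_delta=0.2):
    # Works on any float tensor with a defined (H, W, C) shape, so it can be
    # mapped over a parsed tf.data.Dataset without going through PIL.
    image = tf.image.random_brightness(image, max_delta=max_delta)
    # Keep pixel values in the valid [0, 1] range after the shift.
    return tf.clip_by_value(image, 0.0, 1.0)
```

Because it stays in graph-compatible TensorFlow ops, it can be chained after the parsing function with `dataset.map`.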

Helper functions to write and read the TFRecord format

Converting to TFRecord:

def convert(image_paths, labels, out_path):
    # Args:
    # image_paths   List of file-paths for the images.
    # labels        Class-labels for the images.
    # out_path      File-path for the TFRecords output file.

    print("Converting: " + out_path)

    # Number of images. Used when printing the progress.
    num_images = len(image_paths)

    # Open a TFRecordWriter for the output-file.
    with tf.io.TFRecordWriter(out_path) as writer:

        # Iterate over all the image-paths and class-labels.
        for i in range(num_images):
            # Print the percentage-progress.
            print_progress(count=i, total=num_images - 1)

            # Load the image-file using matplotlib's imread function.
            path = image_paths[i]
            img = imread(path)
            path = path.split('/')

            # Convert the image to raw bytes.
            img_bytes = img.tobytes()

            # Get the label index.
            label = int(path[4])

            # Create a dict with the data we want to save in the
            # TFRecords file. You can add more relevant data here.
            data = \
                {
                    'image': wrap_bytes(img_bytes),
                    'label': wrap_int64(label)
                }

            # Wrap the data as TensorFlow Features.
            feature = tf.train.Features(feature=data)

            # Wrap again as a TensorFlow Example.
            example = tf.train.Example(features=feature)

            # Serialize the data.
            serialized = example.SerializeToString()

            # Write the serialized data to the TFRecords file.
            writer.write(serialized)
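If storing the dimensions is acceptable, the convert step could additionally write them as int64 features so the flat byte buffer can be reshaped later. A sketch, with inline versions of the same wrap_bytes/wrap_int64 helpers used above (the serialize_with_shape name is illustrative):

```python
import numpy as np
import tensorflow as tf

def wrap_int64(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def wrap_bytes(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def serialize_with_shape(img, label):
    # Store height/width/channels alongside the raw pixel bytes so the
    # flat buffer can be restored to (H, W, C) at parse time.
    h, w, c = img.shape
    data = {
        'image': wrap_bytes(img.tobytes()),
        'label': wrap_int64(label),
        'height': wrap_int64(h),
        'width': wrap_int64(w),
        'channels': wrap_int64(c),
    }
    features = tf.train.Features(feature=data)
    return tf.train.Example(features=features).SerializeToString()
```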

Parsing function

def parsing_fn(serialized):
    # Define a dict with the data-names and types we expect to
    # find in the TFRecords file.
    # It is a bit awkward that this needs to be specified again,
    # because it could have been written in the header of the
    # TFRecords file instead.
    features = \
        {
            'image': tf.io.FixedLenFeature([], tf.string),
            'label': tf.io.FixedLenFeature([], tf.int64)
        }

    # Parse the serialized data so we get a dict with our data.
    parsed_example = tf.io.parse_single_example(serialized=serialized,
                                                features=features)

    # Get the image as raw bytes.
    image_raw = parsed_example['image']

    # Decode the raw bytes so it becomes a tensor with type.
    image = tf.io.decode_raw(image_raw, tf.uint8)
    
    # The type is now uint8 but we need it to be float.
    image = tf.cast(image, tf.float32)

    # Get the label associated with the image.
    label = parsed_example['label']
    # The image and label are now correct TensorFlow types.
    return image, label
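If height/width/channels had been written as extra int64 features at conversion time, the parsing function could recover the original layout with tf.reshape. A hypothetical sketch (the 'height'/'width'/'channels' keys are assumptions, not part of the original records):

```python
import tensorflow as tf

def parsing_fn_with_shape(serialized):
    features = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
        'height': tf.io.FixedLenFeature([], tf.int64),
        'width': tf.io.FixedLenFeature([], tf.int64),
        'channels': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized, features)

    # Decode the raw bytes into a flat uint8 tensor.
    image = tf.io.decode_raw(parsed['image'], tf.uint8)

    # Restore the original (H, W, C) layout from the stored dimensions.
    shape = tf.stack([parsed['height'], parsed['width'], parsed['channels']])
    image = tf.reshape(image, shape)

    image = tf.cast(image, tf.float32)
    return image, parsed['label']
```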

Solution

  • You need to use tf.io.encode_jpeg when converting and tf.io.decode_jpeg when parsing. When you decode the JPEG, its dimensions are preserved.

    More specifically, the encoding looks something like this:

    img_bytes = tf.io.gfile.GFile(path, 'rb').read()
    image = tf.io.decode_jpeg(img_bytes, channels=3)
    img_bytes = tf.io.encode_jpeg(tf.cast(image, tf.uint8))
    

    And during parsing

    image = tf.io.decode_jpeg(image_raw)
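Put together, the answer's approach amounts to a round trip through the JPEG codec: the JPEG header carries the width, height, and channel count, so no separate shape features are needed. A minimal sketch (the function names are illustrative):

```python
import tensorflow as tf

def encode_image(img_tensor):
    # JPEG bytes carry width/height/channels in their header,
    # so the shape survives serialization for free.
    return tf.io.encode_jpeg(tf.cast(img_tensor, tf.uint8)).numpy()

def decode_image(image_bytes):
    # The decoded tensor comes back with shape (H, W, 3),
    # recovered from the JPEG header.
    return tf.io.decode_jpeg(image_bytes, channels=3)
```

Note that JPEG is lossy, so pixel values may change slightly across the round trip; if exact pixel values matter, tf.io.encode_png / tf.io.decode_png behaves the same way but losslessly.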