About shapes when reading data from tfrecords


I'm going to read 'image'(2000) and 'landmarks'(388) from tfrecords.

This is the relevant part of the code:

filename_queue = tf.train.string_input_producer([savepath])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
    serialized_example,
    features={
        'label': tf.FixedLenFeature([], tf.string),
        'img_raw': tf.FixedLenFeature([], tf.string),
    })

image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [224, 224, 3])
image = tf.cast(image, tf.float32)

label = tf.decode_raw(features['label'], tf.float64) # problem is here
label = tf.cast(label, tf.float32)
label = tf.reshape(label, [388])

The error is:

InvalidArgumentError (see above for traceback): Input to reshape is a tensor with 291 values, but the requested shape has 388.

When I change 'float64' to 'float32':

label = tf.decode_raw(features['label'], tf.float32) # problem is here

#Error: InvalidArgumentError (see above for traceback): Input to reshape is a tensor with 582 values, but the requested shape has 388

Or to 'float16':

label = tf.decode_raw(features['label'], tf.float16) # problem is here

#Error: InvalidArgumentError (see above for traceback): Input to reshape is a tensor with 1164 values, but the requested shape has 388

And here is how I created the tfrecords (I have simplified some of the code):

writer = tf.python_io.TFRecordWriter(savepath)
for i in range(number_of_images):
    img = Image.open(ImagePath[i])  # load one image from path
    landmark = landmark_read_from_csv[i]  # shape of landmark_read_from_csv is (number_of_images, 388)
    example = tf.train.Example(features=tf.train.Features(feature={
        'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[landmark.tobytes()])),
        'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()]))}))
    writer.write(example.SerializeToString())
writer.close()

I have 3 questions:

  1. Why does the shape change after the data type is changed?
  2. How do I choose the proper data type? (Sometimes I can successfully decode an image with 'tf.float64', but with a different data set I need 'tf.uint8'.)
  3. Is there any problem with the code that creates the tfrecords?

Solution

  • I recently ran into a very similar issue. I am not 100% sure, but from that experience I am fairly confident about the following answers.

    1. The shape changes because different data types occupy a different number of bytes per value. A float16 is half the size of a float32, so the same byte string can be read either as a sequence of n float32 values or as 2n float16 values. In other words, the bytes you are decoding don't change when you change the data type; what changes is how they are partitioned into values.

    2. Check the data type of the data you used to generate the tfrecord file, and use that same data type to decode the bytes when reading it back (you can check the data type of a NumPy array with its .dtype attribute).

    3. None that I can see but I may be wrong.
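To make points 1 and 2 concrete, here is a minimal sketch using NumPy only. It uses np.frombuffer as a stand-in for tf.decode_raw, which is an assumption made so the example runs without TensorFlow. The array contents and the helper count_as are hypothetical; only the length 388 and the value counts 291, 582 and 1164 come from the question.

```python
import numpy as np

# Hypothetical stand-in for one row of landmark_read_from_csv:
# 388 values, assumed here to be float64.
landmark = np.arange(388, dtype=np.float64)
raw = landmark.tobytes()  # the bytes the writer would put in the BytesList

# Point 2: check the dtype before writing, and decode with the same one.
print(landmark.dtype, len(raw))


def count_as(buf, dtype):
    """Number of values obtained when buf is reinterpreted as dtype."""
    return len(buf) // np.dtype(dtype).itemsize


# Point 1: the same bytes split into a different number of values per dtype.
counts = {name: count_as(raw, name) for name in ("float64", "float32", "float16")}
print(counts)

# Decoding with the dtype used for writing restores the original array.
decoded = np.frombuffer(raw, dtype=landmark.dtype)
round_trip_ok = np.array_equal(decoded, landmark)

# The three errors in the question all describe the same byte string:
# 291 * 8 == 582 * 4 == 1164 * 2 bytes.
question_bytes = 291 * 8
bytes_per_landmark = question_bytes / 388
print(question_bytes, bytes_per_landmark)
```

With these assumptions the three counts are 388, 776 and 1552, which is the same halving and doubling pattern seen in the error messages. Applying the same arithmetic to the numbers in the question gives 2328 bytes per record, or 6 bytes per landmark, which is not the size of float16, float32 or float64. That suggests the array read from the CSV may not have had a float dtype at all (a fixed-width string dtype is one possibility), so printing landmark.dtype before writing, and casting if needed, for example with landmark.astype(np.float32), looks like a sensible first check.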