Tags: python, image-processing, tensorflow2.0, tfrecord

Only a single image is extracted from TFRecords


As the title states, only one image/label pair is loaded from my TFRecord files. Each TFRecord file holds a variable number of image/label pairs, but always at least 8. I am using TensorFlow 2.4.1.

Possibly related: I am also getting this warning:

WARNING:tensorflow:AutoGraph could not transform <function parse_tfr_element at 0x7fbb7db99160> and will run it as-is. Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, export AUTOGRAPH_VERBOSITY=10) and attach the full output. Cause: module 'gast' has no attribute 'Index' To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
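
The Cause line in the warning points at the gast package (which AutoGraph uses to parse Python source) rather than at parse_tfr_element itself. Below is a minimal sketch for checking which versions are installed. It assumes, which this post does not confirm, that the AttributeError comes from a gast release that does not match what TensorFlow 2.4.1 expects. installed_version is a hypothetical helper name.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Print the versions of the two packages named in the warning.
for name in ("tensorflow", "gast"):
    print(name, installed_version(name))
```

If the two versions turn out to be mismatched, comparing them against the requirements listed for the installed TensorFlow release would be the next step.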

Below are the functions I use to load the test data. Any help is appreciated.

import tensorflow as tf

def parse_tfr_element(element):
    # Schema of one serialized example; tf.string is a byte string, not a text string.
    data = {
        'height': tf.io.FixedLenFeature([], tf.int64),
        'width': tf.io.FixedLenFeature([], tf.int64),
        'depth': tf.io.FixedLenFeature([], tf.int64),
        'raw_label': tf.io.FixedLenFeature([], tf.string),
        'raw_image': tf.io.FixedLenFeature([], tf.string),
    }

    content = tf.io.parse_single_example(element, data)

    height = content['height']
    width = content['width']
    depth = content['depth']
    raw_label = content['raw_label']
    raw_image = content['raw_image']

    # Get our 'feature' (the image) and reshape it appropriately.
    feature = tf.io.parse_tensor(raw_image, out_type=tf.float16)
    feature = tf.reshape(feature, shape=[height, width, depth])
    # The label is reshaped to the same height and width as the image.
    label = tf.io.parse_tensor(raw_label, out_type=tf.int8)
    label = tf.reshape(label, shape=[height, width])
    return (feature, label)

def get_batched_dataset(filenames):
    # AUTO and BATCH_SIZE are assumed to be defined elsewhere in the script
    # (AUTO is typically tf.data.experimental.AUTOTUNE).
    option_no_order = tf.data.Options()
    option_no_order.experimental_deterministic = False

    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.with_options(option_no_order)
    dataset = dataset.map(parse_tfr_element, num_parallel_calls=AUTO)

    dataset = dataset.shuffle(2048)
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)

    return dataset

Solution

  • It turns out the problem had nothing to do with the functions I posted in the question. I was passing the following value for steps_per_epoch to the model:

    steps_per_epoch = len(training_filenames) // BATCH_SIZE

    Since each file holds multiple cases, len(training_filenames) must first be multiplied by the number of cases in each file:

    steps_per_epoch = len(training_filenames) * images_in_file // BATCH_SIZE
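
Because the question says the number of pairs per file varies, a single images_in_file value can only approximate the total. Below is a minimal sketch of the same arithmetic using per-file counts. The counts and BATCH_SIZE are made-up numbers for illustration, and steps_from_counts is a hypothetical helper name. The real counts could be collected, for example, by iterating once over tf.data.TFRecordDataset for each file.

```python
def steps_from_counts(records_per_file, batch_size):
    """Number of full batches in one pass over all records."""
    total_records = sum(records_per_file)
    # Floor division matches batch(..., drop_remainder=True).
    return total_records // batch_size


# Hypothetical values for illustration only (not taken from the post).
BATCH_SIZE = 8
records_per_file = [8, 12, 10, 9] * 4  # 16 files, 156 records in total

# Original formula: counts files, not records.
wrong = len(records_per_file) // BATCH_SIZE
# Corrected: counts every record across all files.
right = steps_from_counts(records_per_file, BATCH_SIZE)

print(wrong, right)  # 2 19
```

With these numbers the original formula ends each epoch after 2 batches, even though the files hold enough data for 19.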