
Determining the Epoch Number with tf.train.string_input_producer in tensorflow


I have some doubts about how tf.train.string_input_producer works. Suppose I feed filename_list as an input parameter to string_input_producer. Then, according to the documentation https://www.tensorflow.org/programmers_guide/reading_data, this creates a FIFOQueue for which I can set the number of epochs, shuffle the file names, and so on. In my case, I have 4 file names ("db1.tfrecords", "db2.tfrecords", ...), and I use tf.train.batch to feed the network batches of images. Each file name/database contains the images of one person: the first database is for the first person, the second for the second person, and so on. So far I have the following code:

import tensorflow as tf

# `common` (the path prefix), `num_epoch` and `batch_size` are defined elsewhere.
tfrecords_filename_seq = [(common + "P16_db.tfrecords"), (common + "P17_db.tfrecords"), (common + "P19_db.tfrecords"),
                          (common + "P21_db.tfrecords")]

filename_queue = tf.train.string_input_producer(tfrecords_filename_seq, num_epochs=num_epoch, shuffle=False, name='queue')
reader = tf.TFRecordReader()

key, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
    serialized_example,
    # Defaults are not specified since all keys are required.
    features={
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
        'image_raw': tf.FixedLenFeature([], tf.string),
        'annotation_raw': tf.FixedLenFeature([], tf.string)
    })

image = tf.decode_raw(features['image_raw'], tf.uint8)
height = tf.cast(features['height'], tf.int32)
width = tf.cast(features['width'], tf.int32)

image = tf.reshape(image, [height, width, 3])

annotation = tf.cast(features['annotation_raw'], tf.string)

min_after_dequeue = 100
num_threads = 4
capacity = min_after_dequeue + num_threads * batch_size
label_batch, images_batch = tf.train.batch([annotation, image],
                                                        shapes=[[], [112, 112, 3]],
                                                        batch_size=batch_size,
                                                        capacity=capacity,
                                                        num_threads=num_threads)

Finally, when viewing the reconstructed images at the output of the autoencoder, I noticed that I first got the images from the 1st database, then started seeing images from the second database, and so on.
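One way to confirm that ordering is to batch the key returned by reader.read() alongside the image; for TFRecordReader the key is a string that identifies the source file and record (treat the exact format as an implementation detail and just print it). A diagnostic sketch, reusing the tensors and parameters defined above:

# Diagnostic sketch: batch the reader's key so we can see which file
# each example in a batch came from.
key_batch, image_batch_dbg = tf.train.batch([key, image],
                                            shapes=[[], [112, 112, 3]],
                                            batch_size=batch_size,
                                            capacity=capacity,
                                            num_threads=num_threads)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    print(sess.run(key_batch))  # early batches: keys all point at the same file
    coord.request_stop()
    coord.join(threads)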

My question: How can I know whether I'm still within the same epoch? And, if I'm within the same epoch, how can I merge a batch of images from all the file_names that I have?

Finally, I tried to print out the value of the epoch by evaluating the local variable within the Session as follows:

epoch_var = tf.local_variables()[0]

Then:

with tf.Session() as sess:
    print(sess.run(epoch_var))  # Here I got 9 as output; I don't know why.
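
For reference, here is a minimal sketch of how that counter can be inspected (assuming TF 1.x graph mode; the variable name is an internal detail of string_input_producer, so confirm it by printing tf.local_variables()). Note that the counter tracks how many epochs of file names the producer thread has already enqueued, which runs ahead of what the training step has actually consumed because of prefetching; that is one plausible explanation for the unexpected 9.

# Sketch (TF 1.x): locate and read the epoch counter that
# string_input_producer creates as a local variable. Its name is an
# implementation detail, e.g. 'queue/limit_epochs/epochs:0' here;
# print tf.local_variables() to confirm.
epoch_var = [v for v in tf.local_variables() if 'epochs' in v.name][0]

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    print(sess.run(epoch_var))  # epochs of file names already enqueued
    coord.request_stop()
    coord.join(threads)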

Any help is much appreciated!!


Solution

  • What I figured out is that using tf.train.shuffle_batch_join solves my issue, since it shuffles images coming from the different datasets. In other words, every batch now contains images from all the datasets/file_names. Here is an example:

    def read_my_file_format(filename_queue):
        reader = tf.TFRecordReader()
        key, serialized_example = reader.read(filename_queue)
        features = tf.parse_single_example(
            serialized_example,
            # Defaults are not specified since all keys are required.
            features={
                'height': tf.FixedLenFeature([], tf.int64),
                'width': tf.FixedLenFeature([], tf.int64),
                'image_raw': tf.FixedLenFeature([], tf.string),
                'annotation_raw': tf.FixedLenFeature([], tf.string)
            })
    
        # This is how we create one example, that is, extract one example from the database.
        image = tf.decode_raw(features['image_raw'], tf.uint8)
        # The height and width are used to reshape the flattened image below.
        height = tf.cast(features['height'], tf.int32)
        width = tf.cast(features['width'], tf.int32)
    
        # The image is reshaped since, when stored in binary format, it is flattened. Therefore, we need the
        # height and the width to restore the original image.
        image = tf.reshape(image, [height, width, 3])
    
        annotation = tf.cast(features['annotation_raw'], tf.string)
        return annotation, image
    
    def input_pipeline(filenames, batch_size, num_threads, num_epochs=None):
        filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=False,
                                                        name='queue')
        # Note that we create num_threads readers here, all reading from the same filename_queue.
        example_list = [read_my_file_format(filename_queue=filename_queue) for _ in range(num_threads)]
        min_after_dequeue = 100
        capacity = min_after_dequeue + num_threads * batch_size
        label_batch, images_batch = tf.train.shuffle_batch_join(example_list,
                                                                shapes=[[], [112, 112, 3]],
                                                                batch_size=batch_size,
                                                                capacity=capacity,
                                                                min_after_dequeue=min_after_dequeue)
        return label_batch, images_batch, example_list
    
    label_batch, images_batch, input_ann_img = \
        input_pipeline(tfrecords_filename_seq, batch_size, num_threads, num_epochs=num_epoch)
    

    And now this will create a number of readers reading from the FIFOQueue, and each reader will have its own decoder. Finally, after decoding, the images will be fed into another queue, created by the call to tf.train.shuffle_batch_join, which feeds the network batches of images.
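
    For completeness, here is a minimal driver loop for this pipeline (a sketch, assuming TF 1.x graph mode and the label_batch/images_batch tensors returned by input_pipeline above). The queue runners must be started explicitly, the epoch counter lives in the local variables, and the end of the final epoch surfaces as tf.errors.OutOfRangeError, which is the reliable signal that all epochs of all file names have been consumed:

    with tf.Session() as sess:
        # Local variables hold the epoch counter created by string_input_producer.
        sess.run(tf.group(tf.global_variables_initializer(),
                          tf.local_variables_initializer()))
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while not coord.should_stop():
                labels, images = sess.run([label_batch, images_batch])
                # Each batch now mixes images from all four files; train here.
        except tf.errors.OutOfRangeError:
            print('Finished all epochs of all file names.')
        finally:
            coord.request_stop()
            coord.join(threads)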