Tags: python, tensorflow, tensorboard

How to get the previous batch after running Iterator.get_next() in TensorFlow?


Recently, I wanted to implement a GAN model, using tf.data.Dataset and Iterator to read face images as the training data.

The code for the dataset and iterator objects is:

self.dataset = tf.data.Dataset.from_tensor_slices(convert_to_tensor(self.data_ob.train_data_list, dtype=tf.string))
self.dataset = self.dataset.map(self._parse_function)
#self.dataset = self.dataset.shuffle(buffer_size=10000)
self.dataset = self.dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))

self.iterator = tf.data.Iterator.from_structure(self.dataset.output_types, self.dataset.output_shapes)
self.next_x = self.iterator.get_next()
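
One thing worth noting about this setup: an iterator created with `Iterator.from_structure` is not bound to any dataset until you run an explicit initializer, so `get_next()` will fail until then. A minimal self-contained sketch of the same pattern (with made-up stand-in data, and `batch(..., drop_remainder=True)` in place of the deprecated `tf.contrib.data.batch_and_drop_remainder`, using the `tf.compat.v1` graph-mode API):

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Hypothetical stand-in for the face images: four 2x2 "images".
data = np.arange(16, dtype=np.float32).reshape(4, 2, 2)

dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.batch(2, drop_remainder=True)  # replaces tf.contrib.data.batch_and_drop_remainder

iterator = tf.data.Iterator.from_structure(
    tf.data.get_output_types(dataset), tf.data.get_output_shapes(dataset))
next_x = iterator.get_next()

# from_structure iterators start uninitialized; bind them to a dataset first.
init_op = iterator.make_initializer(dataset)

with tf.Session() as sess:
    sess.run(init_op)
    first_batch = sess.run(next_x)  # NumPy array of shape (2, 2, 2)
```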

My new GAN model is:

self.z_mean, self.z_sigm = self.Encode(self.next_x)
self.z_x = tf.add(self.z_mean, tf.sqrt(tf.exp(self.z_sigm))*self.ep)
self.x_tilde = self.generate(self.z_x, reuse=False)
#the feature
self.l_x_tilde, self.De_pro_tilde = self.discriminate(self.x_tilde)

#for Gan generator
self.x_p = self.generate(self.zp, reuse=True)
# the loss of dis network
self.l_x,  self.D_pro_logits = self.discriminate(self.next_x, True)

So, the problem is that I use self.next_x as the input tensor twice, and each use pulls a different batch from the dataset. How can I keep the first batch so it can be reused?


Solution

  • What I use in my code is the following, where x and y_true are placeholders. I'm not sure whether there are more efficient implementations.

    images, labels = session.run(next_element)
    batch_accuracy = session.run(accuracy, feed_dict={x: images, y_true: labels, keep_prop: 1.0})
    batch_predicted_probabilities = session.run(y_pred, feed_dict={x: images, y_true: labels, keep_prop: 1.0})
    
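    The idea behind this fetch-then-feed pattern is that `session.run(next_element)` advances the iterator exactly once and hands back plain NumPy arrays, which you can then feed into as many ops as you like. A minimal runnable sketch (the `doubled` and `summed` ops are hypothetical stand-ins for the encoder and discriminator, using the `tf.compat.v1` graph-mode API):

    ```python
    import numpy as np
    import tensorflow.compat.v1 as tf

    tf.disable_eager_execution()

    dataset = tf.data.Dataset.from_tensor_slices(
        np.arange(8, dtype=np.float32).reshape(4, 2)).batch(2)
    next_element = tf.data.make_one_shot_iterator(dataset).get_next()

    # Placeholder stands in for x in the answer above.
    x = tf.placeholder(tf.float32, shape=[None, 2])
    doubled = x * 2.0          # hypothetical op 1 (e.g. the encoder input)
    summed = tf.reduce_sum(x)  # hypothetical op 2 (e.g. the discriminator input)

    with tf.Session() as sess:
        batch = sess.run(next_element)  # iterator advances exactly once
        a = sess.run(doubled, feed_dict={x: batch})
        b = sess.run(summed, feed_dict={x: batch})  # same batch, reused
    ```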

    I'm currently trying out tf.placeholder_with_default instead of normal placeholders for x and y_true, to check whether it gives better performance in my project. I'll edit my answer once I have results :).

    Edit: I switched to placeholder_with_default and it gave no noticeable speed improvement per batch, at least with the way I'm measuring it.
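
    For reference, the placeholder_with_default variant looks like this: when nothing is fed, the placeholder falls back to its default (here the iterator output, which advances the iterator), and when you feed it, the iterator is left untouched. A small sketch under the same assumptions as above (stand-in data, `tf.compat.v1` graph mode):

    ```python
    import numpy as np
    import tensorflow.compat.v1 as tf

    tf.disable_eager_execution()

    dataset = tf.data.Dataset.from_tensor_slices(
        np.arange(8, dtype=np.float32).reshape(4, 2)).batch(2)
    next_element = tf.data.make_one_shot_iterator(dataset).get_next()

    # If x is not fed, it evaluates next_element (advancing the iterator);
    # feeding x reuses an already-fetched batch without touching the iterator.
    x = tf.placeholder_with_default(next_element, shape=[None, 2])
    doubled = x * 2.0

    with tf.Session() as sess:
        from_iterator = sess.run(doubled)                    # pulls batch 1
        batch = sess.run(next_element)                       # pulls batch 2 as NumPy
        from_feed = sess.run(doubled, feed_dict={x: batch})  # reuses batch 2
    ```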