So I am writing a GAN in TensorFlow, and I need the discriminator and generator to be objects. Now I am having problems creating the training dataset for the discriminator.
Currently the relevant part of my code looks like this:
self.dataset=tf.data.Dataset.from_tensor_slices((self.y_,self.x_)) #creates dataset
self.fake_dataset=tf.data.Dataset.from_tensor_slices((self.x_fake_)) #creates dataset
self.dataset=self.dataset.shuffle(buffer_size=BUFFER_SIZE) #shuffles
self.fake_dataset=self.fake_dataset.shuffle(buffer_size=BUFFER_SIZE) #shuffles
self.dataset=self.dataset.repeat().batch(self.batch_size) #batches
self.fake_dataset=self.fake_dataset.repeat().batch(self.batch_size) #batches
self.iterator=tf.data.Iterator.from_structure(self.dataset.output_types,self.dataset.output_shapes) #creates iterators
self.fake_iterator=tf.data.Iterator.from_structure(self.fake_dataset.output_types,self.fake_dataset.output_shapes) #creates iterators
self.x=self.iterator.get_next()
self.x_fake=self.fake_iterator.get_next()
self.dataset_init_op = self.iterator.make_initializer(self.dataset,name=self.name+'_dataset_init')
self.fake_dataset_init_op=self.fake_iterator.make_initializer(self.fake_dataset,name=self.name+'_fake_dataset_init') #distinct name from the real-data init op
What I need is for the function to alternately give one batch of self.x, followed by one batch of self.x_fake.
Is there an easy way to do this, or will I have to resort to a counter and an if statement?
I'm not sure I understand exactly what you need, but if you want the same call to use the two iterators alternately, that choice is made at graph construction time, so you can use plain Python logic to pick the iterator you need. For example:
def __init__(self):
    # Make graph and iterators...
    self._use_fake_batch = False

def next_batch(self):
    # Pick the iterator for this call, then flip the flag for the next call
    it = self.fake_iterator if self._use_fake_batch else self.iterator
    self._use_fake_batch = not self._use_fake_batch
    return it.get_next()
Or, without an additional variable, using itertools:
from itertools import chain, repeat

def __init__(self):
    # Make graph and iterators...
    # Endless sequence: iterator, fake_iterator, iterator, fake_iterator, ...
    self._iterators = chain.from_iterable(repeat((self.iterator, self.fake_iterator)))

def next_batch(self):
    return next(self._iterators).get_next()
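Here chain.from_iterable(repeat(pair)) flattens an infinite repetition of the pair; itertools.cycle produces the same alternation more directly. A quick check, with plain strings as hypothetical placeholders for the two iterators:

```python
from itertools import chain, cycle, islice, repeat

# Hypothetical placeholders for self.iterator and self.fake_iterator
real_iterator, fake_iterator = "real_iterator", "fake_iterator"

# Construction used above: an endless stream of pairs, flattened
pairs = chain.from_iterable(repeat((real_iterator, fake_iterator)))

# itertools.cycle gives the same endless alternation in one call
cycled = cycle((real_iterator, fake_iterator))

print(list(islice(pairs, 4)))
# ['real_iterator', 'fake_iterator', 'real_iterator', 'fake_iterator']
print(list(islice(cycled, 4)))
# ['real_iterator', 'fake_iterator', 'real_iterator', 'fake_iterator']
```

With cycle, the __init__ line would become self._iterators = cycle((self.iterator, self.fake_iterator)), and next_batch() stays unchanged.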