python · tensorflow · generative-adversarial-network

TensorFlow: efficient CycleGAN history pool


In the CycleGAN paper there is a mention of a history pool for the discriminator: we keep the last, e.g., 50 samples produced by the generator and feed them to the discriminator. Without the history it's quite simple: we can use tf.data.Dataset and iterators to feed data into the network. But with the history pool, I don't see how to use the tf.data.Dataset API. The code inside my training loop looks something like this:

fx, fy = sess.run(model_ops['fakes'], feed_dict={
    self.cur_x: cur_x,
    self.cur_y: cur_y,
})

cur_x, cur_y = sess.run([self.X_feed.feed(), self.Y_feed.feed()])
feeder_dict = {
    self.cur_x: cur_x,
    self.cur_y: cur_y,
    self.prev_fake_x: x_pool.query(fx, step),
    self.prev_fake_y: y_pool.query(fy, step),
}
# self.cur_x, self.cur_y, self.prev_fake_x, self.prev_fake_y are just placeholders
# x_pool and y_pool are simple wrappers for random sampling from the history pool and saving new images to the pool
for _ in range(dis_train):
    sess.run(model_ops['train']['dis'], feed_dict=feeder_dict)
for _ in range(gen_train):
    sess.run(model_ops['train']['gen'], feed_dict=feeder_dict)

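For context, a wrapper like x_pool/y_pool could look roughly like the following sketch (the class name, the unused step argument, and the 50% swap probability are illustrative choices made to match the calls above and the CycleGAN paper, not the exact code):

import numpy as np

# Rough sketch of a history pool: store generated images and, once the pool
# is full, sometimes return an older image in place of the fresh one.
class HistoryPool(object):
    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.images = []

    def query(self, images, step=None):  # `step` kept only to match the call site
        if self.pool_size == 0:
            return images
        out = []
        for img in images:
            if len(self.images) < self.pool_size:
                self.images.append(img)       # still filling the pool
                out.append(img)
            elif np.random.rand() < 0.5:
                idx = np.random.randint(self.pool_size)
                out.append(self.images[idx])  # return an old fake...
                self.images[idx] = img        # ...and store the new one
            else:
                out.append(img)               # return the fresh fake
        return np.asarray(out)
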
What bothers me about this code is its inefficiency: for example, there is no way to preload the next batch during training, as with the tf.data API's prefetch, and I don't see any way to use the tf.data API here. Does it offer some kind of history pooling that I could use together with prefetching and a generally optimized input pipeline? A similar problem arises when there is a ratio between the discriminator's and the generator's train ops: if I want, for instance, to run 2 steps of the generator's training operation per 1 step of the discriminator's, can that be done using the same data? With the tf.data API, a new sample is drawn from the iterator every time sess.run is called.
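
To make the second point concrete, here is a toy TF 1.x example showing that every sess.run which evaluates the iterator's get_next() tensor pulls a fresh batch, so back-to-back discriminator and generator steps reading straight from the iterator would each see different data:

import tensorflow as tf

# Toy pipeline: each sess.run on `next_batch` advances the iterator.
dataset = tf.data.Dataset.range(10).batch(2).prefetch(1)
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    print(sess.run(next_batch))  # [0 1]
    print(sess.run(next_batch))  # [2 3] -- already a different batch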

Is there any way to implement that properly and efficiently?


Solution

  • So, I found out that the history pooling is already implemented in TFGAN, in the TensorFlow contrib repository (see the sketch below).
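
If I remember correctly, the relevant op is tf.contrib.gan.features.tensor_pool (the exact module path is from memory and may differ between TF 1.x versions). Because it is a plain tensor-in/tensor-out op, it can sit between the generator and the discriminator and still compose with a tf.data pipeline and prefetching. A minimal sketch:

import tensorflow as tf

# Sketch (TF 1.x): tensor_pool keeps an internal pool of past values and,
# with the given probability, returns an older sample instead of the current
# one -- the CycleGAN history pool expressed as a graph op.
fake_y = tf.random_normal([1, 256, 256, 3])  # stand-in for a generator output

pooled_fake_y = tf.contrib.gan.features.tensor_pool(
    fake_y,
    pool_size=50,              # keep roughly the last 50 generated samples
    pooling_probability=0.5)   # 50% chance to swap in an old sample

# pooled_fake_y is what the discriminator should consume instead of fake_y.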