Tags: python, keras, deep-learning, generative-adversarial-network

Training GAN in Keras with .fit_generator()


I have been training a conditional GAN architecture similar to Pix2Pix with the following training loop:

for epoch in range(start_epoch, end_epoch):
    for batch_i, (input_batch, target_batch) in enumerate(dataLoader.load_batch(batch_size)):
        fake_batch = self.generator.predict(input_batch)

        d_loss_real = self.discriminator.train_on_batch(target_batch, valid)
        d_loss_fake = self.discriminator.train_on_batch(fake_batch, invalid)
        d_loss = np.add(d_loss_fake, d_loss_real) * 0.5

        g_loss = self.combined.train_on_batch([target_batch, input_batch], [valid, target_batch])

Now this works well, but it is not very efficient, as the dataloader quickly becomes a time bottleneck. I have looked into the .fit_generator() function that Keras provides, which lets the data generator run in worker threads and is much faster.

self.combined.fit_generator(generator=trainLoader,
                            validation_data=evalLoader,
                            callbacks=[checkpointCallback, historyCallback],
                            workers=1,
                            use_multiprocessing=True)

It took me some time to see that this was incorrect: I was no longer training my generator and discriminator separately, and the discriminator was not being trained at all, since it is set to trainable = False in the combined model. That essentially ruins any kind of adversarial loss, and I might as well train my generator by itself with MSE.
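
For reference, the combined model follows the usual Pix2Pix wiring, roughly like the simplified sketch below (not my exact code; img_shape, the generator and discriminator builders, and the loss weights are placeholders):

# simplified sketch of a Pix2Pix-style combined model: the discriminator is
# frozen here, so fitting the combined model on its own never updates it
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

img_target = Input(shape=img_shape)   # ground-truth image, only carried through
img_input = Input(shape=img_shape)    # conditioning image

fake = self.generator(img_input)

self.discriminator.trainable = False  # frozen inside the combined model only
validity = self.discriminator(fake)

self.combined = Model([img_target, img_input], [validity, fake])
self.combined.compile(loss=['binary_crossentropy', 'mae'],  # adversarial + reconstruction
                      loss_weights=[1, 100],
                      optimizer='adam')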

Now my question is whether there is some workaround, such as training my discriminator inside a custom callback that is triggered on each batch of the .fit_generator() call. It is possible to create custom callbacks, like this for example:

class MyCustomCallback(tf.keras.callbacks.Callback):
  def on_train_batch_end(self, batch, logs=None):
    discriminator.train_on_batch()
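
Fleshed out, I imagine the callback would have to hold its own references to the discriminator, the generator and the data loader, since Keras only passes the batch index to the callback, not the batch data. Something like this rough sketch (DiscriminatorTrainer and its arguments are hypothetical):

class DiscriminatorTrainer(tf.keras.callbacks.Callback):
    def __init__(self, discriminator, generator, data_loader, valid, invalid):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.data_loader = data_loader   # must be indexable by batch number
        self.valid = valid               # same label arrays as in the manual loop
        self.invalid = invalid

    def on_train_batch_end(self, batch, logs=None):
        # re-fetch the batch the combined model was just trained on
        input_batch, target_batch = self.data_loader[batch]
        fake_batch = self.generator.predict(input_batch)
        self.discriminator.train_on_batch(target_batch, self.valid)
        self.discriminator.train_on_batch(fake_batch, self.invalid)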

Another possibility would be to parallelise the original training loop, but I am afraid that I have no time to do that right now.


Solution

  • Update: There are built-in enqueuers for this:

    You can check a quick way to use them in this answer: https://stackoverflow.com/a/59214794/2097240
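
    For reference, a minimal sketch of the enqueuer route, reusing the names from the question and assuming dataLoader is a keras.utils.Sequence that returns (input_batch, target_batch) pairs:

    import tensorflow as tf

    enqueuer = tf.keras.utils.OrderedEnqueuer(dataLoader,
                                              use_multiprocessing=True,
                                              shuffle=True)
    enqueuer.start(workers=4, max_queue_size=10)
    batch_stream = enqueuer.get()          #generator that keeps yielding batches

    for epoch in range(epochs):
        for _ in range(len(dataLoader)):   #one pass over the Sequence per epoch
            input_batch, target_batch = next(batch_stream)

            fake_batch = generator.predict(input_batch)
            d_loss_real = discriminator.train_on_batch(target_batch, valid)
            d_loss_fake = discriminator.train_on_batch(fake_batch, invalid)

            g_loss = combined.train_on_batch([target_batch, input_batch],
                                             [valid, target_batch])

    enqueuer.stop()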


    Old answer:

    I created this parallelized iterator exactly for that purpose. I use it in my own training runs.

    This is how you use it:

    for epoch, batchIndex, originalBatchIndex, xAndY in ParallelIterator(
                                           generator, 
                                           epochs, 
                                           shuffle_bool, 
                                           use_on_epoch_end_from_generator_bool,
                                           workers = 8, 
                                           queue_size=10):
        #loop content
        x_train_batch, y_train_batch = xAndY
        model.train_on_batch(x_train_batch, y_train_batch)
    
    
    

    The generator there should be your dataloader, but it needs to be a keras.utils.Sequence, not just a yield generator.
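
    If your dataloader is currently a plain yield generator, a minimal Sequence wrapper could look roughly like this (a hypothetical sketch; adapt __getitem__ to however you actually build a batch):

    from tensorflow.keras.utils import Sequence

    class BatchSequence(Sequence):
        #hypothetical wrapper exposing pre-built (input, target) batches
        #through __len__/__getitem__ so they can be fetched by index
        def __init__(self, input_batches, target_batches):
            self.input_batches = input_batches
            self.target_batches = target_batches

        def __len__(self):
            return len(self.input_batches)

        def __getitem__(self, idx):
            return self.input_batches[idx], self.target_batches[idx]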

    But it's not very complicated to adapt if you need to. (I just don't know whether a plain yield generator can be parallelized properly, though.)
    In the iterator definition below, you should replace:

    • len(keras_sequence) with steps_per_epoch
    • keras_sequence[i] with next(keras_sequence)
    • and set use_on_epoch_end = False

    And this is the iterator definition:

    
    import multiprocessing.dummy as mp
    import numpy as np
    
    #A generator that wraps a Keras Sequence and simulates a `fit_generator` behavior for custom training loops
    #It will also work with any iterator that has `__len__` and `__getitem__`.    
    def ParallelIterator(keras_sequence, epochs, shuffle, use_on_epoch_end, workers = 4, queue_size = 10):
    
        sourceQueue = mp.Queue()                     #queue for getting batch indices
        batchQueue = mp.Queue(maxsize = queue_size)  #queue for getting actual batches 
        indices = np.arange(len(keras_sequence))     #array of indices to be shuffled
    
        use_on_epoch_end = 'on_epoch_end' in dir(keras_sequence) if use_on_epoch_end == True else False
        batchesLeft = 0
    
    #     printQueue = mp.Queue()                      #queue for printing messages
    #     import threading
    #     screenLock = threading.Semaphore(value=1)
    #     totalWorkers= 0
    
    #     def printer():
    #         nonlocal printQueue, printing
    #         while printing:
    #             while not printQueue.empty():
    #                 text = printQueue.get(block=True)
    #                 screenLock.acquire()
    #                 print(text)
    #                 screenLock.release()
    
        #fills the batch indices queue (called when sourceQueue is empty -> a few batches before an epoch ends)
        def fillSource():
            nonlocal batchesLeft
    
    #         printQueue.put("Iterator: fill source - source qsize = " + str(sourceQueue.qsize()))
            if shuffle == True:
                np.random.shuffle(indices)
    
            #puts the indices in the indices queue
            batchesLeft += len(indices)
    #         printQueue.put("Iterator: batches left:" + str(batchesLeft))
            for i in indices:
                sourceQueue.put(i)
    
        #function that will load batches from the Keras Sequence
        def worker():
            nonlocal sourceQueue, batchQueue, keras_sequence, batchesLeft
    #         nonlocal printQueue, totalWorkers
    #         totalWorkers += 1
    #         thisWorker = totalWorkers
    
            while True:
    #             printQueue.put('Worker: ' + str(thisWorker) + ' will try to get item')
                index = sourceQueue.get(block = True) #get index from the queue
    #             printQueue.put('Worker: ' + str(thisWorker) + ' got item ' +  str(index) + " - source q size = " + str(sourceQueue.qsize()))
    
                if index is None:
                    break
    
                item = keras_sequence[index] #get batch from the sequence
                batchesLeft -= 1
    #             printQueue.put('Worker: ' + str(thisWorker) + ' batches left ' + str(batchesLeft))
    
                batchQueue.put((index,item), block=True) #puts batch in the batch queue
    #             printQueue.put('Worker: ' + str(thisWorker) + ' added item ' + str(index) + ' - queue: ' + str(batchQueue.qsize()))
    
    #         printQueue.put("hitting end of worker" + str(thisWorker))
    
    #       #printing pool that will print messages from the print queue
    #     printing = True
    #     printPool = mp.Pool(1, printer)
    
        #creates the thread pool that will work automatically as we get from the batch queue
        pool = mp.Pool(workers, worker)    
        fillSource()   #at this point, data starts being taken and stored in the batchQueue
    
        #generation loop
        for epoch in range(epochs):
    
            #if not waiting for epoch end synchronization, always keeps 1 epoch filled ahead
            if (use_on_epoch_end == False):
                if epoch + 1 < epochs: #only fill if not last epoch
                    fillSource()
    
            for batch in range(len(keras_sequence)):
    
                #if waiting for epoch end synchronization, wait for workers to have no batches left to get, then call epoch end and fill
                if use_on_epoch_end == True:
                    if batchesLeft == 0:
                        keras_sequence.on_epoch_end()
                        if epoch + 1 < epochs:  #only fill if not last epoch
                            fillSource()
                        else:
                            batchesLeft = -1   #in the last epoch, prevents from calling epoch end again and again
    
                #yields batches for the outside loop that is using this generator
                originalIndex, batchItems = batchQueue.get(block = True)
                yield epoch, batch, originalIndex, batchItems
    
    
    #         print("iterator epoch end")
    #     printQueue.put("closing threads")
    
        #terminating the pool - add None to the queue so any blocked worker gets released
        for i in range(workers):
            sourceQueue.put(None)
        pool.terminate()
        pool.close()
        pool.join()
    #     printQueue.put("terminated")
    
    #     printing = False
    #     printPool.terminate()
    #     printPool.close()
    #     printPool.join()
    
    
        del pool,sourceQueue,batchQueue
    #     del printPool, printQueue
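
    Applied to the training loop from the question, this iterator would be used roughly as follows (a sketch, assuming dataLoader has been rewritten as a keras.utils.Sequence whose items are (input_batch, target_batch) pairs):

    for epoch, batch_i, original_i, (input_batch, target_batch) in ParallelIterator(
                                           dataLoader,
                                           epochs,
                                           True,     #shuffle
                                           False,    #use_on_epoch_end
                                           workers = 8,
                                           queue_size = 10):

        fake_batch = generator.predict(input_batch)

        d_loss_real = discriminator.train_on_batch(target_batch, valid)
        d_loss_fake = discriminator.train_on_batch(fake_batch, invalid)
        d_loss = np.add(d_loss_fake, d_loss_real) * 0.5

        g_loss = combined.train_on_batch([target_batch, input_batch], [valid, target_batch])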