Tags: python, multithreading, multiprocessing, generator

How to make a generator thread-safe?


I have a generator that looks like this:

def data_generator(data_file, index_list,....):
    orig_index_list = index_list
    while True:
        x_list = list()
        y_list = list()
        if patch_shape:
            index_list = create_patch_index_list(orig_index_list, data_file, patch_shape,
                                                 patch_overlap, patch_start_offset,
                                                 pred_specific=pred_specific)
        else:
            index_list = copy.copy(orig_index_list)

        while len(index_list) > 0:
            index = index_list.pop()
            add_data(x_list, y_list, data_file, index, augment=augment, augment_flip=augment_flip,
                     augment_distortion_factor=augment_distortion_factor, patch_shape=patch_shape,
                     skip_blank=skip_blank, permute=permute)
            if len(x_list) == batch_size or (len(index_list) == 0 and len(x_list) > 0):
                yield convert_data(x_list, y_list, n_labels=n_labels, labels=labels,
                                   num_model=num_model, overlap_label=overlap_label)
                x_list = list()
                y_list = list()

My dataset is 55 GB and is stored as a single .h5 file (data.h5). Reading the data is extremely slow: one epoch takes about 7000 s, and I get a segmentation fault after about 6 epochs.

I thought that setting use_multiprocessing = False and workers > 1 would speed up reading the data:

model.fit(use_multiprocessing=False, workers=8)

But when I do that I get the following error:

RuntimeError: Your generator is NOT thread-safe. Keras requires a thread-safe generator when use_multiprocessing=False, workers > 1.

Is there a way to make my generator thread-safe? Or is there any other efficient way to generate this data?


Solution

  • I believe the LockedIterator class I referenced in my comment above is incorrect; it should instead be written as in the example below:

    import threading
    
    class LockedIterator(object):
        """Wraps an iterator/generator so that advancing it is serialized by a lock."""
        def __init__(self, it):
            self.lock = threading.Lock()
            self.it = iter(it)

        def __iter__(self):
            return self

        def __next__(self):
            # Only one thread at a time may advance the underlying iterator.
            with self.lock:
                return next(self.it)
                
    def gen():
        for x in range(10):
            yield x
    
    new_gen = LockedIterator(gen())  # single thread-safe iterator shared by both threads

    def worker(g):
        for x in g:
            print(x, flush=True)
    
    t1 = threading.Thread(target=worker, args=(new_gen,))
    t2 = threading.Thread(target=worker, args=(new_gen,))
    t1.start()
    t2.start()
    t1.join()
    t2.join()
    

    Prints:

    0
    1
    23
    
    4
    5
    6
    7
    8
    9
    

    If you want to guarantee that exactly one value is printed per line, we would also need to pass a threading.Lock instance to each thread and issue the print call while holding that lock, so that printing is serialized.
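
    For example, here is a minimal sketch of that approach (the print_lock name and the extra worker parameter are illustrative additions, not part of the original code). It repeats the LockedIterator class from above so the snippet runs on its own; each thread still shares one locked iterator, but now also holds a separate lock while printing:

    import threading

    class LockedIterator(object):
        def __init__(self, it):
            self.lock = threading.Lock()
            self.it = iter(it)

        def __iter__(self):
            return self

        def __next__(self):
            with self.lock:
                return next(self.it)

    def gen():
        for x in range(10):
            yield x

    new_gen = LockedIterator(gen())
    print_lock = threading.Lock()  # serializes writes to stdout

    def worker(g, output_lock):
        for x in g:
            # Hold the print lock while writing so output lines cannot interleave.
            with output_lock:
                print(x, flush=True)

    t1 = threading.Thread(target=worker, args=(new_gen, print_lock))
    t2 = threading.Thread(target=worker, args=(new_gen, print_lock))
    t1.start()
    t2.start()
    t1.join()
    t2.join()

    In the question's setting, the same wrapper would presumably be applied to the generator itself, e.g. wrapping data_generator(data_file, index_list, ...) in LockedIterator before passing it to model.fit(..., use_multiprocessing=False, workers=8), so that several worker threads can pull batches from one shared, lock-protected iterator.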