Tags: python-3.x, neural-network, multiprocessing, python-multiprocessing

Multiprocessing in Python for training neural networks simultaneously


I have a genetic algorithm that performs the hyperparameter search for a neural network. There are 10 generations, and in each generation 20 neural networks are generated. Right now I train the networks of each generation one at a time, so it takes a long time. Instead, I tried multiprocessing so that all 20 neural networks in a generation are trained in parallel. However, when I do that, my neural network information is not being updated. This is what I've done to train the 20 neural networks one by one:

from tqdm import tqdm

def train_networks(networks, dataset):
    """Train each network.
    Args:
        networks (list): Current population of networks
        dataset (str): Dataset to use for training/evaluating
    """


    print('training each network')
    pbar = tqdm(total=len(networks))
    for network in networks:
        print('training network - inside loop')
        network.train(dataset)
        pbar.update(1)
    pbar.close()
    print('done training')

I want to use multiprocessing here. And for multiprocessing I did the following:

def train_networks(networks, dataset):
    """Train each network.
    Args:
        networks (list): Current population of networks
        dataset (str): Dataset to use for training/evaluating
    """


    for network in networks:
        p = multiprocessing.Process(target=network.train,args=(dataset,))
        p.start()
        p.join()

But this doesn't seem to work. How can I modify my code so that all 20 networks are trained in parallel? Help would be appreciated.


Solution

  • The p.join() call blocks until process p has finished, so starting and joining each process inside the loop still trains the networks one at a time. One way to run them concurrently is to start all the processes first and join them afterwards:

    processes = []
    for network in networks:
        p = multiprocessing.Process(target=network.train, args=(dataset,))
        p.start()
        processes.append(p)

    # Now you can wait for the networks to finish training before
    # executing the rest of the script
    for process in processes:
        process.join()


    Here's a nice resource on multiprocessing
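    Note that this alone may not fix the "information is not being updated" part of the question: each child process works on a pickled copy of its network, so weights updated inside train() never reach the parent process. A common workaround is to have the training function return the trained network (or its score) and collect the results in the parent, e.g. with a multiprocessing.Pool. Here is a minimal sketch; the Network class below is a hypothetical stand-in for the question's network class, and it assumes train() can be made to return the trained object:

    ```python
    import multiprocessing

    class Network:
        """Hypothetical stand-in for the question's network class."""
        def __init__(self, nid):
            self.nid = nid
            self.accuracy = None

        def train(self, dataset):
            # Placeholder for real training. Returning self lets the
            # parent process recover the state updated in the child.
            self.accuracy = 0.5 + self.nid / 100.0
            return self

    def train_networks(networks, dataset):
        # Each worker receives a copy of its network, so mutations made
        # inside train() are invisible to the parent unless the trained
        # network is returned and collected here.
        with multiprocessing.Pool(processes=len(networks)) as pool:
            results = [pool.apply_async(n.train, (dataset,)) for n in networks]
            return [r.get() for r in results]

    if __name__ == '__main__':
        trained = train_networks([Network(i) for i in range(4)], 'some-dataset')
        print([n.accuracy for n in trained])
    ```

    The networks and the dataset argument must be picklable for this to work, since they are sent to the worker processes. Also keep in mind that if the real training runs on a GPU, 20 concurrent processes may contend for the same device, so a smaller pool size can be more effective.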