
Why is my Python parallel process running too many times?


I've written a function that carries out a stochastic simulation on a system of chemical reactions.

def gillespie_tau_leaping(start_state, LHS, stoch_rate, state_change_array): # inputs are a series of arrays
    t = SimulationTimer()
    t.start()
    #update molecule numbers for each species in model
    #update current time of the system 
    t.stop()
    print(f"Accumulated time: {t.get_accumulated_time():0.10f} seconds")
    # popul_num_all is an array of changing molecule numbers over time; tao_all is the evolution of time throughout the simulation
    return popul_num_all, tao_all, t.get_accumulated_time()

popul_num_all, tao_all, accumulated_elapsed_time = gillespie_tau_leaping(start_state, LHS, stoch_rate, state_change_array)  # Call the function to make variables accessible for plotting. 

I've now written the following code to run the gillespie_tau_leaping function multiple times using Python's multiprocessing Pool class.

if __name__ == '__main__':
    with Pool() as p:
        pool_results = p.starmap(gillespie_tau_leaping, [(start_state, LHS, stoch_rate, state_change_array) for i in range(4)])
        p.close()
        p.join()   
        total_time = 0.0
        for tuple_results in pool_results:
            total_time += tuple_results[2]
    print(f"Total time:\n{total_time}") 


def gillespie_plot(tao_all, popul_num_all):
    fig, (ax1, ax2) = plt.subplots(1, 2)
    ax1.plot(tao_all, popul_num_all[:, 0], label='S', color= 'Green')
    ax1.legend()
    for i, label in enumerate(['T', 'U']):
        ax2.plot(tao_all, popul_num_all[:, i+1], label=label)
    ax2.legend()
    plt.tight_layout()
    plt.show()
    return fig

gillespie_plot(tao_all, popul_num_all)

gillespie_plot plots the changing molecule numbers, popul_num_all, against time, tao_all.

However, when I run this code, gillespie_tau_leaping is simulated 9 times. The first run is expected, because I call the function directly to make some variables accessible. It's the next 8 simulations I don't understand: the first 4 simulate the system and plot the graphs but don't return the total_time of the parallel simulations, while the second 4 don't plot the graphs but do return the total_time.

I only expect (and want) 4 simulations after the function call, which should plot the graphs and return total_time.

What am I doing wrong?

Cheers


Solution

  • The breakdown of your program is as follows:

    1. On startup, the main process defines gillespie_tau_leaping() and then calls it at module level: popul_num_all, tao_all, accumulated_elapsed_time = gillespie_tau_leaping(start_state, LHS, stoch_rate, state_change_array)

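    2. The main process evaluates __name__ == '__main__' as true and so starts the multiprocessing pool. Pool() creates os.cpu_count() worker processes (presumably 4 on your machine), and with the spawn start method (the default on Windows and macOS) each worker re-runs your script from the top, performing the following steps: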

      1. The subprocess starts up, re-imports your script, and so defines and calls gillespie_tau_leaping() (as in step 1).
      2. The subprocess evaluates __name__ == '__main__' as false (the script is imported under the name __mp_main__), so it does not create a new pool.
      3. The subprocess then defines and calls gillespie_plot() with the arguments from step 2.1
      4. The subprocess receives the request to call gillespie_tau_leaping() that comes from the starmap call and processes it, returning the result.
    3. The main process receives the results of the starmap invocation (2.4) and prints out the time result

    4. The main process defines and calls gillespie_plot() with the arguments from step 1
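The key point is that under the spawn start method each worker executes your script again, but with __name__ set to '__mp_main__' instead of '__main__' (CPython's multiprocessing does this internally with runpy.run_path). Here is a minimal sketch that reproduces the effect without starting any processes, assuming that behaviour; SCRIPT, simulate() and CALLS are hypothetical stand-ins for your code:

```python
import os
import runpy
import tempfile
import textwrap

# Hypothetical stand-in for your script: one unguarded module-level call
# and one call inside the __main__ guard; CALLS records which ones ran.
SCRIPT = textwrap.dedent("""
    CALLS = []

    def simulate(tag):
        CALLS.append(tag)

    simulate("module-level call")  # unguarded: runs every time the file is executed

    if __name__ == "__main__":
        simulate("guarded call")  # runs only when the file is the main program
""")


def run_script_as(run_name):
    """Execute SCRIPT from a temporary file with __name__ set to run_name."""
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
        f.write(SCRIPT)
        path = f.name
    try:
        # run_path returns (a copy of) the executed module's globals
        module_globals = runpy.run_path(path, run_name=run_name)
    finally:
        os.remove(path)
    return module_globals["CALLS"]


main_calls = run_script_as("__main__")       # like the parent process
worker_calls = run_script_as("__mp_main__")  # like a spawned worker importing your script
print(main_calls)    # ['module-level call', 'guarded call']
print(worker_calls)  # ['module-level call']
```

Only the unguarded call runs in the "worker", which is exactly the extra simulation (plus plot) you see once per worker.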

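    To run the simulation only 4 times, put everything except the imports and definitions under the if __name__ == '__main__': guard, so the workers only import the definitions, and plot the pooled results in the main process: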

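    from multiprocessing import Pool
    
    import matplotlib.pyplot as plt
    
    # start_state, LHS, stoch_rate and state_change_array are assumed to be defined
    # at module level as in your script (each spawned worker re-creates them on import)
    
    def gillespie_tau_leaping(start_state, LHS, stoch_rate, state_change_array):
        ...
    
    def gillespie_plot(tao_all, popul_num_all):
        fig, (ax1, ax2) = plt.subplots(1, 2)
        ax1.plot(tao_all, popul_num_all[:, 0], label='S', color='Green')
        ax1.legend()
        for i, label in enumerate(['T', 'U']):
            ax2.plot(tao_all, popul_num_all[:, i+1], label=label)
        ax2.legend()
        plt.tight_layout()
        # plt.show() is not called here because it blocks; call it once after creating all figures
        return fig
    
    if __name__ == '__main__':
        with Pool() as p:
            pool_results = p.starmap(gillespie_tau_leaping, [(start_state, LHS, stoch_rate, state_change_array) for _ in range(4)])
            # starmap() blocks until all results are in; leaving the with block then calls p.terminate()
        
        total_time = 0.0
        for tuple_results in pool_results:
            total_time += tuple_results[2]
        print(f"Total time:\n{total_time}")
    
        # Each result is (popul_num_all, tao_all, accumulated_time), matching the function's return order
        for popul_num_all, tao_all, _elapsed in pool_results:
            gillespie_plot(tao_all, popul_num_all)  # one figure per run
        plt.show()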
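If you want to check the restructured layout end to end before plugging in the real model, here is a minimal runnable sketch. toy_simulation() is a hypothetical random-walk stand-in for gillespie_tau_leaping() (it is not a Gillespie algorithm) and run_simulations() is a hypothetical helper; each task gets its own seed so the runs differ but are reproducible. Plotting is left out so it runs headless.

```python
import random
import time
from multiprocessing import Pool


def toy_simulation(seed, n_steps=1000):
    """Hypothetical stand-in for gillespie_tau_leaping: a simple random walk.

    Returns (populations, times, elapsed) in the same order as the question's
    function returns (popul_num_all, tao_all, accumulated_time).
    """
    start = time.perf_counter()
    rng = random.Random(seed)  # per-task seed: runs differ from each other but are reproducible
    population, t = 100, 0.0
    populations, times = [population], [t]
    for _ in range(n_steps):
        population = max(0, population + rng.choice((-1, 1)))
        t += rng.expovariate(1.0)
        populations.append(population)
        times.append(t)
    return populations, times, time.perf_counter() - start


def run_simulations(n_runs, processes=None):
    """Run n_runs simulations in a pool; return (results, summed elapsed time)."""
    with Pool(processes=processes) as pool:
        # starmap preserves task order, so results[i] came from seed i
        results = pool.starmap(toy_simulation, [(seed,) for seed in range(n_runs)])
    total_time = sum(elapsed for _pops, _times, elapsed in results)
    return results, total_time


if __name__ == "__main__":
    # All side effects live under the guard, so spawned workers only
    # import the definitions above and then run their starmap tasks.
    results, total_time = run_simulations(4)
    print(f"Runs: {len(results)}")
    print(f"Total time:\n{total_time}")
```

Run it as a saved script rather than in an interactive session: with spawn, each worker must be able to re-import the file, which is also why toy_simulation() has to be defined at module level.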