Tags: python, python-multiprocessing

Python: How do I define a global variable accessible by a multiprocessing pool from command line arguments?


I have a script that produces files from a large dataset, so I'm using multiprocessing to speed things up. The problem is that my script accepts several command line arguments (via the argparse library) that change the results, and I'm struggling to pass those arguments through to the function called by my multiprocessing pool.

I'm sure the solution to this is really simple; I'm just not seeing it. I figured I would make a global variable that gets updated to reflect the command line args, but the function called by the pool still sees the old value. I've tried to illustrate my problem below:

import argparse
import multiprocessing
import os

output_dir = 'default'

def do_task(item):
    print(output_dir) # Prints 'default'
    result = process_item(item)
    write_to_file(data=result, location=os.path.join(output_dir, item.name))

def do_multi_threaded_work(data_path):
    print(output_dir) # Prints command line argument
    data = read_from_file(data_path)
    pool = multiprocessing.Pool()
    for i, _ in enumerate(pool.imap_unordered(do_task, data)):
        print('Completed task %d/%d' % (i + 1, len(data)))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-o', '--output-dir')
    parser.add_argument('-i', '--input-file')
    args = parser.parse_args()
    output_dir = args.output_dir
    do_multi_threaded_work(args.input_file)

How can I ensure that I am saving my files to the correct directory according to the command line arguments?

Edit: It's been suggested that I do something like the code below; however, considering that my actual code has quite a lot of constants (I simplified it to just one for this example), this seems very messy and counter-intuitive. Is there really no better way to just set a global constant accessible by the do_task function, without hard-coding the value?

from itertools import repeat
...
def do_multi_threaded_work(data_path):
    ...
    for i, _ in enumerate(pool.imap_unordered(do_task, zip(data, repeat(output_dir)))):
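
For reference, here is roughly what that suggestion looks like end to end. This is only a sketch: process_item, write_to_file and read_from_file are the same placeholder helpers as above, and output_dir is the same module-level global set from the command line arguments. Note that do_task now has to unpack an (item, constant) tuple itself, which is exactly the part that becomes unwieldy once there are several constants:

import multiprocessing
import os
from itertools import repeat

def do_task(task_args):
    # Each worker receives an (item, output_dir) tuple and must unpack it.
    item, output_dir = task_args
    result = process_item(item)
    write_to_file(data=result, location=os.path.join(output_dir, item.name))

def do_multi_threaded_work(data_path):
    data = read_from_file(data_path)
    pool = multiprocessing.Pool()
    tasks = zip(data, repeat(output_dir))  # pair every item with the same constant
    for i, _ in enumerate(pool.imap_unordered(do_task, tasks)):
        print('Completed task %d/%d' % (i + 1, len(data)))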

Solution

  • In the end, I found a solution using functools.partial. It lets me bind any constant parameters by creating a partial function with those values already filled in; I then pass that partial function, along with the iterable, to the pool.

    import argparse
    import multiprocessing
    import os
    from functools import partial
    
    def do_task(output_dir, item):
        print(output_dir) # Prints the command line argument bound via partial
        result = process_item(item)
        write_to_file(data=result, location=os.path.join(output_dir, item.name))
    
    def do_multi_threaded_work(data_path):
        print(output_dir) # Prints command line argument
        data = read_from_file(data_path)
        func = partial(do_task, output_dir)
        pool = multiprocessing.Pool()
        for i, _ in enumerate(pool.imap_unordered(func, data)):
            print('Completed task %d/%d' % (i + 1, len(data)))
    
    if __name__ == '__main__':
        parser = argparse.ArgumentParser()
        parser.add_argument('-o', '--output-dir')
        parser.add_argument('-i', '--input-file')
        args = parser.parse_args()
        output_dir = args.output_dir
        do_multi_threaded_work(args.input_file)
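
  • Since my real code has several constants rather than the single one shown here, it's worth adding that partial can bind any number of them at once, and the pool still only ever sees a one-argument function. A rough sketch (the file_ext argument is invented purely for illustration and would come from an extra argparse option; process_item, write_to_file and data are the same placeholders as above):

    def do_task(output_dir, file_ext, item):
        # Constants come first, the per-item argument last, matching the partial below.
        result = process_item(item)
        write_to_file(data=result, location=os.path.join(output_dir, item.name + file_ext))

    func = partial(do_task, args.output_dir, args.file_ext)
    pool = multiprocessing.Pool()
    for i, _ in enumerate(pool.imap_unordered(func, data)):
        print('Completed task %d/%d' % (i + 1, len(data)))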