Search code examples
python python-3.x numpy multiprocessing

Using multiprocessing Pool with a function that takes a list of arrays


I am trying to write a function to read a large number of files in parallel. The code I have for this goes like:

import numpy as np
from multiprocessing import Pool
from functools import partial

def read_profiles(stamp,name,cols,*args):
   """Read one profile file and store selected columns into row ``stamp``.

   The file ``<name>-<int(timestep[stamp])>.dat`` is read line by line. Each
   line is split on whitespace and, for every array in ``args``, the field at
   the matching position in ``cols`` is converted to ``float`` and written to
   ``array[stamp, line_number]``.

   Args:
      stamp: Index into the module-level ``timestep`` sequence; also the row
         of each output array that gets filled.
      name: Filename prefix (everything before ``-<timestep>.dat``).
      cols: Column indices to extract, one per array in ``args``.
      *args: Arrays indexable as ``array[row, column]``; modified in place.

   Raises:
      ValueError: If ``cols`` and ``args`` do not have the same length.
   """
   if len(cols) != len(args):
      raise ValueError(
         "expected one column index per output array, got "
         + str(len(cols)) + " column(s) for " + str(len(args)) + " array(s)"
      )

   filename = name + '-' + str(int(timestep[stamp])) + '.dat'
   with open(filename) as f:
      # enumerate() supplies the running line number used as the column index.
      for xloc, line in enumerate(f):
         ele = line.split()
         # Pair each output array with its source column. The previous
         # `for g in len(args)` tried to iterate over an int and raised
         # TypeError on the first line of every file.
         for target, col in zip(args, cols):
            target[stamp, xloc] = float(ele[col])

# Timestep labels 1..50; read_profiles uses them to build each filename.
timestep = np.arange(1,51)
# Output buffers, one row per timestep and 1001 columns each.
# np.ndarray(shape=...) allocates without initialising, so any entry that
# read_profiles does not write holds arbitrary values.
x = np.ndarray(shape=(len(timestep),1001))
Ex = np.ndarray(shape=(len(timestep),1001))
j1 = np.ndarray(shape=(len(timestep),1001))
j2 = np.ndarray(shape=(len(timestep),1001))
j3 = np.ndarray(shape=(len(timestep),1001))
j4 = np.ndarray(shape=(len(timestep),1001))

# Arrays grouped per file family; each group is meant to be forwarded as the
# variadic *args of read_profiles.
terse_args = [x,Ex]
curr_args = [j1,j2,j3,j4]

# Map every timestep index over a pool of 4 worker processes.
# NOTE: `*args=*terse_args` is not valid Python -- a starred name cannot be
# used as a keyword -- and is the "invalid syntax" error described below.
# NOTE(review): Pool sends pickled copies of the arrays to the workers, so
# in-place writes made there are presumably not visible in this process --
# confirm, or use shared memory / return values instead.
with Pool(4) as pool:
   pool.map(partial(read_profiles,name='terse',cols=[0,2],*args=*terse_args),range(len(timestep)))
   pool.map(partial(read_profiles,name='current',cols=[1,2,3,4],*args=*curr_args),range(len(timestep)))

Note that the last argument (*args) takes an unknown number of 2D arrays. The above code gives me an error saying "invalid syntax" at *args. I have tried passing them as positional arguments without the keywords but then I get an error saying multiple values for 'name'.

Does anyone know how I can include an arbitrary number of 2D arrays as an argument to the function while using pool.map and partial?

Please let me know if any other information is required. Thank you.


Solution

  • One possible solution is to create a custom partial (for example, a class with a __call__ magic method), e.g.:

    from functools import partial
    from multiprocessing import Pool
    
    import numpy as np
    
    
    def read_profiles(stamp, name, cols, *args):
        """Announce the data file that belongs to time index ``stamp``.

        Stand-in for the real reader: it only assembles the filename
        ``<name>-<timestep>.dat``, with the timestep looked up in the
        module-level ``timestep`` sequence, and prints it. ``cols`` and
        ``args`` are accepted but not used here.
        """
        step_label = str(int(timestep[stamp]))
        filename = "".join((name, "-", step_label, ".dat"))
        print(f"Opening {filename=}")
    
    
    class MyPartial:
        """Callable that pre-binds everything ``read_profiles`` needs except ``stamp``.

        Instances keep the bound values in the plain attributes ``name``,
        ``cols`` and ``args``.
        """

        def __init__(self, name, cols, *args):
            """Remember the filename prefix, column indices and extra positionals."""
            self.name, self.cols, self.args = name, cols, args

        def __call__(self, stamp):
            """Forward ``stamp`` plus the stored values to ``read_profiles``."""
            bound = (self.name, self.cols) + tuple(self.args)
            return read_profiles(stamp, *bound)
    
    
    # Defined at module level rather than under the ``__main__`` guard:
    # read_profiles looks ``timestep`` up as a global inside the worker
    # processes. With the "spawn" start method the workers re-import this
    # module without running the guarded block, so a definition placed there
    # would be missing in the workers and the lookup would raise NameError.
    timestep = np.arange(1, 51)


    if __name__ == "__main__":
        # Output buffers, one row per timestep and 1001 columns each
        # (np.ndarray(shape=...) allocates without initialising).
        x = np.ndarray(shape=(len(timestep), 1001))
        Ex = np.ndarray(shape=(len(timestep), 1001))
        j1 = np.ndarray(shape=(len(timestep), 1001))
        j2 = np.ndarray(shape=(len(timestep), 1001))
        j3 = np.ndarray(shape=(len(timestep), 1001))
        j4 = np.ndarray(shape=(len(timestep), 1001))

        # Arrays grouped per file family; unpacked into MyPartial's *args.
        terse_args = [x, Ex]
        curr_args = [j1, j2, j3, j4]

        # Each MyPartial instance is called once per timestep index by the
        # pool of 4 worker processes.
        with Pool(4) as pool:
            pool.map(
                MyPartial("terse", [0, 2], *terse_args),
                range(len(timestep)),
            )

            pool.map(
                MyPartial("current", [1, 2, 3, 4], *curr_args),
                range(len(timestep)),
            )
    

    Prints:

    ...
    
    Opening filename='current-46.dat'
    Opening filename='current-47.dat'
    Opening filename='current-48.dat'
    Opening filename='current-49.dat'
    Opening filename='current-50.dat'