Tags: python, pandas, dask, dask-distributed, dask-delayed

Synchronize dask map_partitions with print functions


I have the following code:

import dask.dataframe as dd

def func1(df):
    x = 1
    print('Processing func1')
    return x

def func2(df):
    x = 2
    print('Processing func2')
    return x

# df is an existing pandas DataFrame
ddf = dd.from_pandas(df, npartitions=3)

print('func1 processing started')
ddf.map_partitions(func1)
print('func1 processing ended')

print('func2 processing started')
ddf.map_partitions(func2)
print('func2 processing ended')

ddf.compute()

What I'm looking for is a way to log (or, in this case, print) the steps before, during, and after each map_partitions call is executed.

However, since Dask builds the graph lazily, ddf.compute() triggers the map_partitions calls only after all the print statements have already run, so I get something like this:

func1 processing started
func1 processing ended
func2 processing started
func2 processing ended
Processing func1
Processing func1
Processing func1
Processing func2
Processing func2
Processing func2

Instead, I need

func1 processing started
Processing func1
Processing func1
Processing func1
func1 processing ended
func2 processing started
Processing func2
Processing func2
Processing func2
func2 processing ended

How can I make this work? Note: my example uses print, but I would like to synchronize map_partitions with any Python function.

UPDATE

A more realistic scenario:

import dask.dataframe as dd

def func1():
    # dd.read_csv has no npartitions argument, so repartition after reading
    ddf = dd.read_csv('file.csv').repartition(npartitions=3)
    log('In func1')
    ddf = func11(ddf, 123)
    log('func1 ended')
    ddf.compute()


def func11(df, x):
    log('In func11')
    # ... do stuff with df
    df = func111(df, x)
    return df


def func111(df, x):
    log('In func111')
    df = df.map_partitions(func1111)
    return df


def func1111(df):
    df['abc'] = df['abc'] * 2
    log('Processing func1111')
    return df


def log(msg):
    print(msg)  # or log to a file or DB

The requirement is that this should print:

In func1
In func11
In func111
Processing func1111
Processing func1111
Processing func1111
func1 ended

Solution

  • You can wrap ddf in a class that queues up the ddf.map_partitions() and log() calls and replays them at compute() time, using ddf.persist() and wait(ddf) to block until all pending partition work has finished before each log call runs.

    from dask.distributed import wait
    
    
    class QueuedMapPartitionsWrapper:
        def __init__(self, ddf, queue=None):
            self.ddf = ddf
            self.queue = queue or []
    
        def map_partitions(self, func, *args, **kwargs):
            # Return a new wrapper with the transformation queued, not applied
            return self.__class__(self.ddf, self.queue + [(True, func, args, kwargs)])
    
        def log(self, *args, **kwargs):
            # Queue a call to the global log() function
            return self.__class__(self.ddf, self.queue + [(False, log, args, kwargs)])
    
        def compute(self):
            ddf = self.ddf
            for (map_partitions, func, args, kwargs) in self.queue:
                if map_partitions:
                    ddf = ddf.map_partitions(func, *args, **kwargs)
                else:
                    # Flush everything queued so far: persist() starts the work,
                    # wait() blocks until it is done, then the log call runs
                    ddf = ddf.persist()
                    wait(ddf)
                    func(*args, **kwargs)
            return ddf.compute()
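
    The key primitive here is persist() plus wait(): persist() submits all of the work queued so far to the scheduler, and wait() blocks until it has finished, so anything logged afterwards really does run after every partition has been processed. A minimal standalone sketch of that idea, assuming a local dask.distributed client (the DataFrame and func1 below are stand-ins):

    import pandas as pd
    import dask.dataframe as dd
    from dask.distributed import Client, wait
    
    def func1(part):
        print('Processing func1')
        return part
    
    client = Client(processes=False)  # in-process workers, so their prints share this console
    ddf = dd.from_pandas(pd.DataFrame({'abc': range(6)}), npartitions=3)
    
    print('func1 processing started')
    ddf = ddf.map_partitions(func1)
    ddf = ddf.persist()   # start executing the graph now
    wait(ddf)             # block until every partition is done
    print('func1 processing ended')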
    

    Usage 1:

    ddf = dd.from_pandas(df, npartitions=3)
    ddf = QueuedMapPartitionsWrapper(ddf)   # wrap: from here on, calls are queued
    
    ddf = ddf.log('func1 processing started')
    ddf = ddf.map_partitions(func1)
    ddf = ddf.log('func1 processing ended')
    
    ddf = ddf.log('func2 processing started')
    ddf = ddf.map_partitions(func2)
    ddf = ddf.log('func2 processing ended')
    
    ddf.compute()
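
    With a local client, this interleaves the log messages and the partition output in the order the question asks for (worker prints land in this console; with remote workers they would show up in the worker logs instead). One caveat: Dask's meta inference calls each mapped function once on an empty partition, so an extra 'Processing funcN' line may appear; passing an explicit meta= to map_partitions avoids that.

    func1 processing started
    Processing func1
    Processing func1
    Processing func1
    func1 processing ended
    func2 processing started
    Processing func2
    Processing func2
    Processing func2
    func2 processing ended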
    

    Usage 2:

    def func1():
        # dd.read_csv has no npartitions argument, so repartition after reading
        ddf = dd.read_csv('file.csv').repartition(npartitions=3)
        ddf = QueuedMapPartitionsWrapper(ddf)    # Either here
    
        log('In func1')
    
        ddf = func11(ddf, 123)
        # ddf = QueuedMapPartitionsWrapper(ddf)  # or here (just before ddf.log)
        ddf = ddf.log('func1 ended')
    
        ddf.compute()
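
    Note that wait() looks up the active dask.distributed client, so this pattern assumes one has been created before compute() runs. A minimal setup for running func1 under a local client (a sketch; processes=False just keeps the workers in-process so their prints share this console):

    from dask.distributed import Client
    
    client = Client(processes=False)
    func1()

    The plain log('In func11') / log('In func111') calls inside the helpers run eagerly while the graph is being built, so they naturally print before any partition work; only the logs that must come after partition work (like 'func1 ended') need the queueing machinery.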