Tags: python-3.x, dask, dask-distributed

Dask worker post-processing


I'm new to Dask and am trying to implement some post-processing tasks that run when workers shut down. I'm currently using an EC2Cluster with n_workers=5.

The cluster is created each time I need to run my large task. The task outputs a bunch of files which I want to send to AWS S3.

How would I implement a "post-processing" function that would run on each worker to send any logs and outputs to an S3 bucket?

Thanks in advance

import threading
import time

from dask.distributed import Client, LocalCluster, WorkerPlugin, get_worker


def complex():
    time.sleep(10)
    print('hard')
    print(get_worker().id)

    return 'hard'


class DaskWorkerHandler(WorkerPlugin):
    """
    Worker life-cycle handler
    """
    def __init__(self):
        self.worker_id = None

    def setup(self, worker):
        self.worker_id = worker.id

    def teardown(self, worker):
        print(f"tearing down - {self.worker_id}. thread={threading.get_ident()}")

        # some post processing on the worker server
        # eg. post files to S3 etc...


if __name__ == '__main__':
    cluster = LocalCluster(n_workers=5)
    print(f"cluster_name={cluster.name}")

    shutdown_handler = DaskWorkerHandler()
    client = Client(cluster)
    client.register_worker_plugin(shutdown_handler)

    future = client.submit(complex)
    result = future.result()

Solution

  • You can use Python’s standard logging module to log whatever you'd like as the workers are running and then use the worker plugin you wrote to save these logs to an S3 bucket on teardown (check out the docs on logging in Dask for more details). Here's an example:

    import logging

    import fsspec
    from dask.distributed import Client, LocalCluster, WorkerPlugin
    
    def complex():
        logger = logging.getLogger("distributed.worker")
        logger.error("Got here")
        return 'hard'
    
    
    class DaskWorkerHandler(WorkerPlugin):
        """Worker life-cycle handler."""
        def __init__(self):
            self.worker_id = None
    
        def setup(self, worker):
            self.worker_id = worker.id
    
        def teardown(self, worker):
            logs = worker.get_logs()
            # replace with S3 path
            with fsspec.open(f"worker-{self.worker_id}-logs.txt", "w") as f:
                f.write("\n".join([str(log) for log in logs]))
    
    
    cluster = LocalCluster(n_workers=5)
    client = Client(cluster)
    
    shutdown_handler = DaskWorkerHandler()
    client.register_worker_plugin(shutdown_handler)
    
    future = client.submit(complex)
    result = future.result()

    client.close()
    cluster.close()
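
  • If you also want to ship the output files themselves (not just the worker logs), the same teardown hook can copy them to S3 with fsspec. The sketch below is a minimal, hypothetical example rather than a fixed recipe: it assumes s3fs is installed so fsspec can resolve the "s3" protocol, that each worker writes its outputs to a local output/ directory, and that s3://my-bucket/results is a placeholder destination. Register it with client.register_worker_plugin just like the logging plugin above.

    import os

    import fsspec
    from dask.distributed import WorkerPlugin


    class OutputUploader(WorkerPlugin):
        """Copy files from a local output directory to S3 when a worker shuts down."""

        def __init__(self, local_dir="output", s3_prefix="s3://my-bucket/results"):
            # local_dir and s3_prefix are illustrative placeholders
            self.local_dir = local_dir
            self.s3_prefix = s3_prefix
            self.worker_id = None

        def setup(self, worker):
            self.worker_id = worker.id

        def teardown(self, worker):
            # requires s3fs so that fsspec can handle s3:// paths
            fs = fsspec.filesystem("s3")
            for name in os.listdir(self.local_dir):
                local_path = os.path.join(self.local_dir, name)
                remote_path = f"{self.s3_prefix}/{self.worker_id}/{name}"
                # copy the local file to the remote (S3) filesystem
                fs.put(local_path, remote_path)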