python dask dask-distributed dask-delayed

Limit the number of CPUs used by dask compute


The code below takes approximately 1 second to execute on an 8-CPU system. How can I manually configure the number of CPUs used by dask.compute, e.g. to 4 CPUs, so that the code takes approximately 2 seconds to execute even on an 8-CPU system?

import dask
from time import sleep

def f(x):
    sleep(1)
    return x**2

objs = [dask.delayed(f)(x) for x in range(8)]
print(dask.compute(*objs))  # (0, 1, 4, 9, 16, 25, 36, 49)

Solution

  • There are a few options:

    1. Specify the number of workers when the cluster is created:
    from dask.distributed import Client
    
    # with threads_per_worker=1 each worker runs one task at a time;
    # without it, a worker may run tasks on all of its threads
    client = Client(n_workers=4, threads_per_worker=1)
    
    # the rest of your code is unchanged
    
    2. Specify how many (and which) workers should execute a task:
    
    client = Client(n_workers=8, threads_per_worker=1)
    
    list_workers = list(client.scheduler_info()['workers'])
    
    # submit only to the first 4 workers; client.compute returns futures,
    # so call client.gather(futures) to collect the results
    futures = client.compute(objs, workers=list_workers[:4])
    
    # note that the workers should still be single-threaded; the difference
    # from option 1 is that you can, in principle, have more workers
    # that stay idle. The `workers` kwarg can also be passed to
    # dask.compute rather than client.compute
    
    3. Use a semaphore:
    from dask.distributed import Client, Semaphore
    
    client = Client()
    sem = Semaphore(max_leases=4, name="foo")
    
    def fmodified(x, sem):
        with sem:
            return f(x)
    
    objs = [dask.delayed(fmodified)(x, sem) for x in range(8)]
    print(dask.compute(*objs))  # (0, 1, 4, 9, 16, 25, 36, 49)
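
The note under option 2 says the `workers` kwarg can be passed to dask.compute directly. A minimal sketch of that variant follows, with a timing check. The helper name compute_on_subset and the use of processes=False (which keeps the workers in the current process so that no __main__ guard is needed here) are assumptions made for illustration and are not part of the original answer.

```python
import time
from time import sleep

import dask
from dask.distributed import Client


def f(x):
    sleep(1)
    return x**2


def compute_on_subset(n_allowed=4):
    """Run eight 1-second tasks on at most `n_allowed` single-threaded workers."""
    objs = [dask.delayed(f)(x) for x in range(8)]
    # Assumption: processes=False starts the workers as threads in this
    # process. Creating the Client makes it the default scheduler, so
    # dask.compute forwards the `workers` kwarg to it.
    with Client(n_workers=8, threads_per_worker=1, processes=False) as client:
        list_workers = list(client.scheduler_info()["workers"])
        start = time.perf_counter()
        results = dask.compute(*objs, workers=list_workers[:n_allowed])
        elapsed = time.perf_counter() - start
    return results, elapsed


if __name__ == "__main__":
    results, elapsed = compute_on_subset(4)
    print(results)
    print(f"elapsed: {elapsed:.1f} s")
```

With four allowed workers, the eight tasks have to run in at least two rounds, so the elapsed time should come out at roughly 2 seconds plus scheduling overhead.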
    

    Update: as noted by @mdurant in the comments, if you are running this in a script, then an if __name__ == "__main__": guard is needed to keep the relevant code from being executed again by the worker processes. For example, the second option from the list above would look like this in a script:

    #!/usr/bin/env python3
    import dask
    from dask.distributed import Client
    from time import sleep
    
    def f(x):
        sleep(1)
        return x**2
    
    objs = [dask.delayed(f)(x) for x in range(8)]
    
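    if __name__ == "__main__":
        client = Client(n_workers=8, threads_per_worker=1)
    
        list_workers = list(client.scheduler_info()['workers'])
    
        # client.compute returns futures; gather them to get the values
        futures = client.compute(objs, workers=list_workers[:4])
    
        print(client.gather(futures))  # [0, 1, 4, 9, 16, 25, 36, 49]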
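
If a distributed cluster is not required, the same limit can likely be reached with dask's local threaded scheduler. This is not one of the options listed in the answer. The sketch below assumes that running the tasks in a local thread pool is acceptable for your workload, and the helper name timed_compute is hypothetical.

```python
import time
from time import sleep

import dask


def f(x):
    sleep(1)
    return x**2


def timed_compute(num_workers):
    """Compute eight 1-second tasks on a local pool of `num_workers` threads."""
    objs = [dask.delayed(f)(x) for x in range(8)]
    start = time.perf_counter()
    # scheduler="threads" selects the local threaded scheduler;
    # num_workers caps the size of its thread pool.
    results = dask.compute(*objs, scheduler="threads", num_workers=num_workers)
    elapsed = time.perf_counter() - start
    return results, elapsed


if __name__ == "__main__":
    results, elapsed = timed_compute(4)
    print(results)
    print(f"elapsed: {elapsed:.1f} s")
```

Because no worker processes are started, this variant needs neither a Client nor special handling of the __main__ guard. With num_workers=4 the eight tasks run in two rounds, so the run should take about 2 seconds.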