
Retries in dask.compute() are unclear


From the documentation: "Number of allowed automatic retries if computing a result fails."

Does "result" refer to each individual task or the entire compute() call?

If it refers to the entire call, how do I implement retries for each task in dask.delayed?

Also, I'm not sure the retries work at all. See the code below:

import dask
import random

@dask.delayed
def add(x, y):
    return x + y

@dask.delayed
def divide(sum_i):
    n = random.randint(0, 1)
    result = sum_i / n
    return result

tasks = []
for i in range(3):
    sum_i = add(i, i+1)
    divide_n = divide(sum_i)
    tasks.append(divide_n)

dask.compute(*tasks, retries=1000)

The expected output is (1.0, 3.0, 5.0), but instead I get a ZeroDivisionError.


Solution

  • If anyone is interested, we use a @retry decorator for tasks, like so:

    @dask.delayed
    @retry(Exception, tries=3, delay=5)
    def my_func():
        pass
    

    Retry decorator:

    import logging
    from functools import wraps
    from time import sleep
    
    def retry(exceptions, tries=4, delay=3, backoff=2, logger=None):
        """
        Retry calling the decorated function using an exponential backoff.
    
        Args:
            exceptions: The exception type to catch; may be a tuple of
                exception types.
            tries: Number of times to try (not retry) before giving up.
            delay: Initial delay between retries in seconds.
            backoff: Backoff multiplier (e.g. value of 2 will double the delay
                each retry).
            logger: Logger to use.
    
        """
        if not logger:
            logger = logging.getLogger(__name__)
    
        def deco_retry(f):
            @wraps(f)
            def f_retry(*args, **kwargs):
                mtries, mdelay = tries, delay
                while mtries > 1:
                    try:
                        return f(*args, **kwargs)
                    except exceptions as e:
                        msg = f"{e}, \nRetrying in {mdelay} seconds..."
                        logger.warning(msg)
                        sleep(mdelay)
                        mtries -= 1
                        mdelay *= backoff
                # Final attempt: any exception now propagates to the caller.
                return f(*args, **kwargs)
            return f_retry  # true decorator
    
        return deco_retry
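To check the same retry pattern without needing dask installed, here is a minimal sketch applied to a variant of the question's `divide` task. It rests on a few assumptions. `retry` below is a condensed, hypothetical stand-in for the decorator above: it has no logging or backoff and uses a `for` loop instead. The scripted `_divisors` sequence replaces `random.randint(0, 1)` so that the run is reproducible. With dask, you would stack `@dask.delayed` on top exactly as in the answer. The delayed wrapper calls the retrying function when the task executes, so each task retries independently.

```python
import time
from functools import wraps


def retry(exceptions, tries=3, delay=0.0):
    """Hypothetical condensed variant of the answer's decorator (no logging, no backoff)."""
    if tries < 1:
        raise ValueError("tries must be at least 1")

    def deco(f):
        @wraps(f)
        def wrapper(*args, **kwargs):
            for attempt in range(1, tries + 1):
                try:
                    return f(*args, **kwargs)
                except exceptions:
                    if attempt == tries:
                        raise  # out of attempts: re-raise the last error
                    time.sleep(delay)  # wait before the next attempt
        return wrapper
    return deco


# Hypothetical scripted divisors standing in for random.randint(0, 1):
# the first two attempts divide by zero, the third one succeeds.
_divisors = iter([0, 0, 1])
calls = 0


@retry(ZeroDivisionError, tries=3, delay=0)
def divide(sum_i):
    global calls
    calls += 1  # count every attempt, including failed ones
    return sum_i / next(_divisors)


result = divide(3)
print(result, calls)  # 3.0 3
```

Re-raising the last exception once the attempts run out, rather than returning None, keeps the failure visible to `dask.compute()`, which then raises it as usual.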