python-asyncio python-3.7 python-decorators

Class as a decorator for regular functions and coroutines in Python


I am trying to create a class as a decorator that wraps the decorated function in a try-except block and keeps a log of the exceptions. I want to apply the decorator to both regular functions and coroutines.

I have written the class-as-decorator and it works as designed for regular functions, but something goes wrong with coroutines. Below is a reduced version of the class-as-decorator and a couple of use cases:

import traceback
import asyncio
import functools

class Try:

    def __init__(self, func):
        functools.update_wrapper(self, func)
        self.func = func

    def __call__(self, *args, **kwargs):
        print(f"applying __call__ to {self.func.__name__}")
        try:
            return self.func(*args, **kwargs)
        except:
            print(f"{self.func.__name__} failed")
            print(traceback.format_exc())

    def __await__(self, *args, **kwargs):
        print(f"applying __await__ to {self.func.__name__}")
        try:
            yield self.func(*args, **kwargs)
        except:
            print(f"{self.func.__name__} failed")
            print(traceback.format_exc())

# Case 1
@Try
def times2(x):
    return x*2/0

# Case 2
@Try
async def times3(x):
    await asyncio.sleep(0.0001)
    return x*3/0

async def test_try():
    return await times3(10)

def main():
    times2(10)
    asyncio.run(test_try())
    print("All done")

if __name__ == "__main__":
    main()

Here's the output of the above code (with minor edits):

applying __call__ to times2
times2 failed
Traceback (most recent call last):
  File "<ipython-input-3-37071526b2e6>", line 14, in __call__
    return self.func(*args, **kwargs)
  File "<ipython-input-3-37071526b2e6>", line 30, in times2
    return x*2/0
ZeroDivisionError: division by zero

applying __call__ to times3
Traceback (most recent call last):
  File "[...]/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3296, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-3-37071526b2e6>", line 46, in <module>
    main()
  File "<ipython-input-3-37071526b2e6>", line 43, in main
    asyncio.run(test_try())
  File "[...]/lib/python3.7/asyncio/runners.py", line 43, in run
    return loop.run_until_complete(main)
  File "[...]/lib/python3.7/asyncio/base_events.py", line 579, in run_until_complete
    return future.result()
  File "<ipython-input-3-37071526b2e6>", line 39, in test_try
    return await times3(10)
  File "<ipython-input-3-37071526b2e6>", line 36, in times3
    return x*3/0
ZeroDivisionError: division by zero

Case 1 behaves normally: as expected, __call__ is called, the decorated function fails, and the exception is caught. But I can't explain the behaviour of Case 2. Notice that the "times3 failed" and "All done" prints are missing at the end. I can't reproduce the color-coded output here, but in PyCharm Case 1's traceback appears as regular printed output, while Case 2's traceback appears in exception red, i.e. as an uncaught exception. The surprising part is that the __call__ method was called instead of __await__.

I have tried another class-as-decorator, one that keeps a tally of the number of times a function was called. That works just fine with __call__ with either regular functions or coroutines.

So what is actually going on? Do I need to somehow force the function to use __await__? How?

I tried the following:

async def test_try2():
    func = await times3

with output

applying __await__ to times3
times3 failed
Traceback (most recent call last):
  File "<ipython-input-5-5a85f988097e>", line 22, in __await__
    yield self.func(*args, **kwargs)
TypeError: times3() missing 1 required positional argument: 'x'

This does force the use of __await__, but then what?


Solution

  • The problem with your code is that it places __await__ on the wrong object. Generally await f(x) expands to something like:

    _awaitable = f(x)
    _iter = _awaitable.__await__()
    yield from _iter  # not literally[1]
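To see this expansion in action, here is a small sketch. Box and Ticket are hypothetical classes made up for illustration: Box, like the question's Try, defines both __call__ and __await__. Awaiting box() runs Box.__call__ first and then the __await__ of whatever that call returned; Box.__await__ is never consulted.

```python
import asyncio

calls = []  # records which special methods run, in order


class Ticket:
    """Hypothetical awaitable returned by Box.__call__."""

    def __await__(self):
        calls.append("Ticket.__await__")
        yield from asyncio.sleep(0).__await__()  # suspend once, like real async code
        return 42


class Box:
    """Hypothetical callable that also defines __await__, like the question's Try."""

    def __call__(self):
        calls.append("Box.__call__")
        return Ticket()

    def __await__(self):  # never used by `await box()`
        calls.append("Box.__await__")
        yield from asyncio.sleep(0).__await__()
        return 0


async def main():
    box = Box()
    # Roughly: _awaitable = box(); yield from _awaitable.__await__()
    return await box()


value = asyncio.run(main())
print(value)  # 42
print(calls)  # ['Box.__call__', 'Ticket.__await__']
```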
    

    Note how __await__() is called on the result of the function call, not on the function object itself. What happens in your times3 example is the following:

    • __call__ calls the original times3 coroutine function in self.func, which trivially constructs a coroutine object. There is no exception at this point because the coroutine hasn't started executing yet, so the coroutine object (what you get by calling an async def coroutine function) is simply returned.

    • __await__ is invoked on the coroutine object obtained by calling self.func (the original times3 async def), not on your function wrapper. In terms of the pseudocode above, your wrapper corresponds to f, whereas __await__() is invoked on _awaitable, which in your case is the result of calling f.
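The first point is easy to check in isolation: calling a coroutine function does not run its body, so the ZeroDivisionError only appears once the coroutine is awaited. A minimal sketch, using an undecorated copy of times3; call_only and call_and_await are hypothetical helper names:

```python
import asyncio


async def times3(x):
    await asyncio.sleep(0)
    return x * 3 / 0


def call_only():
    coro = times3(10)           # no exception: the body has not started running
    kind = type(coro).__name__  # 'coroutine'
    coro.close()                # avoid the "coroutine ... was never awaited" warning
    return kind


async def call_and_await():
    try:
        await times3(10)        # the body runs here, so the error surfaces here
    except ZeroDivisionError:
        return "raised on await"
    return "no error"


print(call_only())                    # coroutine
print(asyncio.run(call_and_await()))  # raised on await
```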

    In general you can't know whether the result of a function call will ever be awaited. But since coroutine objects are not useful for anything other than being awaited (and they even emit a RuntimeWarning when destroyed without being awaited), you can safely assume so. This assumption allows your __call__ to check whether the result of the function call is awaitable and, if so, wrap it in an object that implements your wrapping logic at the __await__ level:

    import collections.abc
    import functools
    import traceback
    
    class Try:
        def __init__(self, func):
            functools.update_wrapper(self, func)
            self.func = func
    
        def __call__(self, *args, **kwargs):
            print(f"applying __call__ to {self.func.__name__}")
            try:
                result = self.func(*args, **kwargs)
            except Exception:  # unlike a bare except, lets KeyboardInterrupt/SystemExit through
                print(f"{self.func.__name__} failed")
                print(traceback.format_exc())
                return None
            if isinstance(result, collections.abc.Awaitable):
                # The result is awaitable, wrap it in an object
                # whose __await__ will call result.__await__()
                # and catch the exceptions.
                return TryAwaitable(result)
            return result
    
    class TryAwaitable:
        def __init__(self, awaitable):
            self.awaitable = awaitable
    
        def __await__(self):
            print(f"applying __await__ to {self.awaitable.__name__}")
            try:
                # The parentheses are required: `return yield from ...` is a SyntaxError.
                return (yield from self.awaitable.__await__())
            except Exception:
                print(f"{self.awaitable.__name__} failed")
                print(traceback.format_exc())
    

    This results in the expected output:

    applying __call__ to times3
    applying __await__ to times3
    times3 failed
    Traceback (most recent call last):
      File "wrap3.py", line 30, in __await__
        yield from self.awaitable.__await__()
      File "wrap3.py", line 44, in times3
        return x*3/0
    ZeroDivisionError: division by zero
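Putting the pieces together, here is a self-contained sketch of the fixed decorator with both cases from the question, so main() now reaches "All done". It deviates from the code above in two deliberate ways, both assumptions of this sketch: the wrapped function's name is passed to TryAwaitable explicitly, so it also works for awaitables that have no __name__ (an asyncio.Future, for example), and the "applying ..." prints are dropped. plus1 is a hypothetical extra case showing that successful results pass through unchanged.

```python
import asyncio
import collections.abc
import functools
import traceback


class Try:
    """Catch and log exceptions from a regular function or a coroutine function."""

    def __init__(self, func):
        functools.update_wrapper(self, func)
        self.func = func

    def __call__(self, *args, **kwargs):
        try:
            result = self.func(*args, **kwargs)
        except Exception:
            print(f"{self.func.__name__} failed")
            print(traceback.format_exc())
            return None
        if isinstance(result, collections.abc.Awaitable):
            # Defer the try/except to the moment the result is actually awaited.
            return TryAwaitable(result, self.func.__name__)
        return result


class TryAwaitable:
    def __init__(self, awaitable, name):
        self.awaitable = awaitable
        self.name = name

    def __await__(self):
        try:
            # Delegate suspension to the wrapped awaitable and pass its value on.
            return (yield from self.awaitable.__await__())
        except Exception:
            print(f"{self.name} failed")
            print(traceback.format_exc())
            return None


@Try
def times2(x):
    return x * 2 / 0


@Try
async def times3(x):
    await asyncio.sleep(0.0001)
    return x * 3 / 0


@Try
async def plus1(x):  # hypothetical: a coroutine that succeeds
    await asyncio.sleep(0)
    return x + 1


async def test_try():
    return await times3(10), await plus1(10)


def main():
    r2 = times2(10)
    r3, r1 = asyncio.run(test_try())
    print("All done")
    return r2, r3, r1


if __name__ == "__main__":
    main()
```

If you also need to recognize generator-based coroutines (generators decorated with types.coroutine), inspect.isawaitable is a broader check than isinstance(result, collections.abc.Awaitable).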
    

    Note that your implementation of __await__ had an unrelated problem: it delegated to the function using yield. You must use yield from instead, because that allows the underlying iterable to choose when to suspend, and also to provide a value once it stops suspending. A bare yield suspends unconditionally (and only once), which is incompatible with the semantics of await.
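The difference shows up directly under asyncio. In this sketch (BareYield, YieldFrom, answer and use are hypothetical names), the bare-yield wrapper hands the coroutine object itself to the event loop, which asyncio's Task rejects with a RuntimeError, while the yield from wrapper runs the coroutine to completion and passes its return value through.

```python
import asyncio


class BareYield:
    def __init__(self, coro):
        self.coro = coro

    def __await__(self):
        # Wrong: suspends once and hands the coroutine object to the event loop.
        yield self.coro


class YieldFrom:
    def __init__(self, coro):
        self.coro = coro

    def __await__(self):
        # Right: the inner coroutine decides when to suspend, and its value is returned.
        return (yield from self.coro.__await__())


async def answer():
    await asyncio.sleep(0)
    return 42


async def use(wrapper_cls):
    coro = answer()
    try:
        return await wrapper_cls(coro)
    finally:
        # No-op if coro already finished; otherwise avoids a "never awaited" warning.
        coro.close()


def run_bare():
    try:
        asyncio.run(use(BareYield))
    except RuntimeError as exc:
        return str(exc)
    return "no error"


print(asyncio.run(use(YieldFrom)))  # 42
print(run_bare())                   # a "Task got bad yield: ..." message
```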

    [1] Not literally, because yield from is not allowed inside async def. Instead, the coroutine object returned by an async def function behaves as if its __await__ method returned such a generator.
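To make the footnote concrete, here is a sketch (suspend_once, compute and drive are hypothetical names) that drives the iterator returned by a coroutine object's __await__() by hand, roughly the way an event loop would: each send(None) resumes it until the next suspension, and the coroutine's return value arrives in StopIteration.value.

```python
import types


@types.coroutine
def suspend_once():
    # Generator-based coroutine: the yielded value goes to whoever drives the outer coroutine.
    yield "suspended"


async def compute():
    await suspend_once()
    return 42


def drive():
    it = compute().__await__()
    first = it.send(None)  # runs compute() up to the yield inside suspend_once
    try:
        it.send(None)      # resumes; compute() returns 42
    except StopIteration as stop:
        return first, stop.value
    raise AssertionError("compute() should have finished")


print(drive())  # ('suspended', 42)
```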