python python-3.x python-asyncio python-3.7

Test if coroutine was awaited or not


I have an asynchronous function which connects to a database. Currently my users do:

conn = await connect(uri, other_params)

I want to continue to support this, but want to additionally allow connect() to be used as a context manager:

async with connect(uri, other_params) as conn:
     pass

The difference between these two scenarios is that in the first case connect is awaited, and in the second case it is not.

Is it possible to tell, within the body of connect, if the coroutine was awaited or not?

My current effort at this is on repl.it.


Solution

  • Here's code that passes the tests you provided:

    import asyncio
    import pytest
    from functools import wraps
    
    
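    def connection_context_manager(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Calling the decorated function no longer returns a coroutine;
            # it returns an object that supports both `await` and `async with`.
            class Wrapper:
                def __init__(self):
                    self._conn = None

                async def __aenter__(self):
                    # `async with connect(...) as conn:` opens the connection here.
                    self._conn = await func(*args, **kwargs)
                    return self._conn

                async def __aexit__(self, *_):
                    # Close the connection when the `async with` block exits.
                    await self._conn.close()

                def __await__(self):
                    # `conn = await connect(...)` delegates to the original coroutine.
                    return func(*args, **kwargs).__await__()  # https://stackoverflow.com/a/33420721/1113207

            return Wrapper()

        return wrapper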
    

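    Note how the three magic methods (__await__, __aenter__ and __aexit__) make the returned object both awaitable and an async context manager at the same time, so connect() itself never needs to know how it is being used.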

    Feel free to ask questions if you have any.
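
To see both call styles working with the decorator above, here is a minimal, self-contained sketch. The decorator is repeated so the snippet runs on its own. FakeConnection and the "db://example" URI are hypothetical stand-ins for illustration, not part of the question's code, and the single-argument connect() is a simplified assumption about the real signature.

```python
import asyncio
from functools import wraps


def connection_context_manager(func):
    # Same decorator as in the answer, repeated so this sketch is runnable.
    @wraps(func)
    def wrapper(*args, **kwargs):
        class Wrapper:
            def __init__(self):
                self._conn = None

            async def __aenter__(self):
                self._conn = await func(*args, **kwargs)
                return self._conn

            async def __aexit__(self, *_):
                await self._conn.close()

            def __await__(self):
                return func(*args, **kwargs).__await__()

        return Wrapper()

    return wrapper


class FakeConnection:
    """Hypothetical stand-in for a real database connection."""

    def __init__(self, uri):
        self.uri = uri
        self.closed = False

    async def close(self):
        self.closed = True


@connection_context_manager
async def connect(uri):
    await asyncio.sleep(0)  # pretend to do some network I/O
    return FakeConnection(uri)


async def demo():
    # Style 1: plain await; the caller is responsible for closing.
    conn1 = await connect("db://example")
    await conn1.close()

    # Style 2: async with; __aexit__ closes the connection on exit.
    async with connect("db://example") as conn2:
        assert not conn2.closed
    return conn1, conn2


conn1, conn2 = asyncio.run(demo())
```

With this approach, calling connect() does no work by itself. The underlying coroutine is only created once the returned object is awaited or entered, and each use opens a fresh connection.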