Tags: python, postgresql, asynchronous, python-asyncio, asyncpg

How can I wrap a coroutine.__await__()?


I want to log every acquire/release call on an asyncpg pool. I have written the following for this purpose:

class CntPoolLogger:
    """Append pool acquire/release events to a log file.

    Each written line has the form ``<event_no>,<event>,<conn_cnt>``:
    ``event_no`` counts every logged event, ``event`` is ``acquire`` or
    ``release`` and ``conn_cnt`` is the running acquire-minus-release balance
    *after* the event.
    """

    def __init__(self, log_filename: str = "pool_usage.log") -> None:
        """Open *log_filename* in append mode.

        :param log_filename: path of the log file; defaults to
            ``"pool_usage.log"`` (the previously hard-coded name).
        """
        self.conn_cnt = 0
        self.event_cnt = 0
        self.log_filename = log_filename
        # Append mode: successive runs accumulate in the same file.
        self.file = open(self.log_filename, "a", encoding="utf-8")

    def log(self, event: str) -> None:
        """Write one ``event_no,event,conn_cnt`` line and flush it."""
        self.file.write(f"{self.event_cnt},{event},{self.conn_cnt}\n")
        # Flush immediately so the log is usable even if the process dies.
        self.file.flush()

    def acquire(self) -> None:
        """Record that one connection has been acquired."""
        self.event_cnt += 1
        self.conn_cnt += 1
        self.log("acquire")

    def release(self) -> None:
        """Record that one connection has been released."""
        self.event_cnt += 1
        self.conn_cnt -= 1
        self.log("release")

    def close(self) -> None:
        """Close the log file. Safe to call repeatedly or on a half-built instance."""
        file = getattr(self, "file", None)
        if file is not None:
            file.close()

    def __del__(self) -> None:
        # __init__ may have failed before ``self.file`` was assigned (e.g. open()
        # raised), so close() must not assume the attribute exists.
        self.close()


class LoggingPoolAcquireContext:
    """Wrap an asyncpg ``PoolAcquireContext`` and log acquire/release events.

    Both ways of using ``pool.acquire()`` are supported:

    * ``async with pool.acquire() as conn: ...``: logged in
      ``__aenter__`` / ``__aexit__``;
    * ``conn = await pool.acquire()``: the acquire is logged by ``__await__``.
      The matching release must be logged by whoever releases the connection.

    In every case the counter is updated only *after* the wrapped operation
    has completed successfully. Acquisitions that fail or are cancelled are
    never counted.
    """

    def __init__(self, pool_acquire_context: "asyncpg.pool.PoolAcquireContext", cnt_logger: "CntPoolLogger"):
        # String annotations: nothing has to be importable just to define
        # this class.
        self.pool_acquire_context = pool_acquire_context
        self.cnt_logger = cnt_logger

    async def __aenter__(self, *args, **kwargs):
        """Acquire through the wrapped context, then log the acquisition."""
        res = await self.pool_acquire_context.__aenter__(*args, **kwargs)
        self.cnt_logger.acquire()
        return res

    async def __aexit__(self, *args, **kwargs):
        """Release through the wrapped context, then log the release."""
        await self.pool_acquire_context.__aexit__(*args, **kwargs)
        self.cnt_logger.release()

    def awaitable_wraper(self, awaitable):
        """Drive *awaitable* to completion, then log one acquisition.

        :param awaitable: the iterator returned by an awaitable's
            ``__await__()``, i.e. anything ``yield from`` can delegate to.
        :returns: a generator usable as an ``__await__`` result. It evaluates
            to whatever *awaitable* produced.
        """
        # ``yield from`` forwards every suspension to the event loop, and
        # forwards every value or exception sent back, so the loop sees
        # exactly what it would see when awaiting the wrapped object directly.
        result = yield from awaitable
        # Reached only on success. A timeout or cancellation propagates out of
        # ``yield from`` above and skips the logging.
        self.cnt_logger.acquire()
        return result

    def __await__(self, *args, **kwargs):
        """Support ``conn = await pool.acquire()``, logging after acquisition."""
        return self.awaitable_wraper(self.pool_acquire_context.__await__(*args, **kwargs))
        
        

class LoggingAsyncPGPool:
    """Thin proxy around a pool that logs connection acquire/release events.

    Only ``acquire``, ``release`` and ``close`` are proxied. Counting happens
    via *cnt_logger* after the underlying pool call has completed.
    """

    def __init__(self, pool, cnt_logger: "CntPoolLogger"):
        # String annotation: nothing has to be importable just to define this.
        self._pool = pool
        self.cnt_logger = cnt_logger

    def acquire(self, *args, **kwargs):
        """Return a logging wrapper around ``self._pool.acquire(...)``."""
        return LoggingPoolAcquireContext(self._pool.acquire(*args, **kwargs), cnt_logger=self.cnt_logger)

    async def release(self, *args, **kwargs):
        """Release through the wrapped pool, then log the release."""
        await self._pool.release(*args, **kwargs)
        # Logged only after the release call returned without raising.
        self.cnt_logger.release()

    async def close(self, *args, **kwargs):
        """Close the wrapped pool, forwarding arguments like the other proxies."""
        await self._pool.close(*args, **kwargs)

I want the counter to be updated only after a connection has actually been acquired/released, so I update the counter only after the `await` statements. However, in this method:

def __await__(self, *args, **kwargs):
        # The problem under discussion: the counter is bumped here, *before*
        # the iterator returned below has been awaited, so an acquisition is
        # logged even if it later fails or is never awaited.
        self.cnt_logger.acquire()
        return self.pool_acquire_context.__await__(*args, **kwargs)

we return an awaitable iterator that is awaited somewhere else, so the counter is updated before we know whether a connection was actually acquired.

My question is: is there a way to wrap

self.pool_acquire_context.__await__(*args, **kwargs)

so that, when it is awaited later, it updates the counter only after a connection has been acquired?


Solution

  • As you discovered, `__await__` must return an iterator, and a generator is the simplest way to write one:

    def __await__(self, *args, **kwargs):
        # Delegate to the wrapped context's iterator. ``yield from`` passes every
        # suspension through to the event loop, and its value is the acquired
        # connection once the wrapped awaitable finishes. (Writing
        # ``return yield from ...`` without parentheses is a SyntaxError.)
        conn = yield from self.pool_acquire_context.__await__(*args, **kwargs)
        # Only reached on success. If acquisition raises (timeout,
        # cancellation, ...), the exception propagates and nothing is logged.
        self.cnt_logger.acquire()
        return conn