
How to correctly specify type hints with AsyncGenerator and AsyncContextManager


Consider the following code:

import contextlib
import abc
import asyncio

from typing import AsyncContextManager, AsyncGenerator, AsyncIterator


class Base:

    @abc.abstractmethod
    async def subscribe(self) -> AsyncContextManager[AsyncGenerator[int, None]]:
        pass

class Impl1(Base):

    @contextlib.asynccontextmanager
    async def subscribe(self) -> AsyncIterator[AsyncGenerator[int, None]]:  # <-- mypy error here

        async def _generator() -> AsyncGenerator[int, None]:
            for i in range(5):
                await asyncio.sleep(1)
                yield i

        yield _generator()

For Impl1.subscribe, mypy reports the error:

Signature of "subscribe" incompatible with supertype "Base"

What is the correct way to specify type hints in the above case? Or is mypy wrong here?


Solution

    I ran into the same problem and found this question on the very same day, and also worked out the answer quickly.

    You need to remove async from the abstract method, i.e. declare it with a plain def.
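Applied to the code from the question, the fix might look like the sketch below. Besides dropping async from Base.subscribe, it makes Base inherit from abc.ABC (so the abstract method is actually enforced), annotates the inner generator, shortens the sleep so the demo runs quickly, and adds a hypothetical main() helper to show usage. None of these extras are needed to make the mypy error go away; only the removal of async is.

```python
import abc
import asyncio
import contextlib
from typing import AsyncContextManager, AsyncGenerator, AsyncIterator


class Base(abc.ABC):
    # No `async` here: calling subscribe() must return the context manager
    # itself, which is what @asynccontextmanager produces in Impl1.
    @abc.abstractmethod
    def subscribe(self) -> AsyncContextManager[AsyncGenerator[int, None]]:
        ...


class Impl1(Base):
    @contextlib.asynccontextmanager
    async def subscribe(self) -> AsyncIterator[AsyncGenerator[int, None]]:
        async def _generator() -> AsyncGenerator[int, None]:
            for i in range(5):
                await asyncio.sleep(0.01)  # shortened from 1 s for the demo
                yield i

        yield _generator()


async def main() -> list[int]:
    # Hypothetical usage: enter the context manager, then iterate the generator.
    async with Impl1().subscribe() as gen:
        return [value async for value in gen]


if __name__ == "__main__":
    print(asyncio.run(main()))  # [0, 1, 2, 3, 4]
```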

    To explain why, I'll reduce the case to a plain async iterator:

    @abc.abstractmethod
    async def foo(self) -> AsyncIterator[int]:
        pass
    
    async def v1(self) -> AsyncIterator[int]:
        yield 0
    
    async def v2(self) -> AsyncIterator[int]:
        return self.v1()
    

    If you compare v1 and v2, you'll see that their signatures look the same, but they do very different things: v2 is compatible with the abstract method, v1 is not.
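The difference also shows up at runtime. The sketch below uses module-level functions instead of methods purely to stay short (an assumption for brevity, not part of the original) and inspects them with the standard inspect module:

```python
import asyncio
import inspect
from typing import AsyncIterator


async def v1() -> AsyncIterator[int]:
    # Contains `yield`: an async generator function.
    # Calling it returns an async iterator directly.
    yield 0


async def v2() -> AsyncIterator[int]:
    # No `yield`: an ordinary coroutine function.
    # Calling it returns a coroutine that resolves to the iterator.
    return v1()


async def main() -> tuple[list[int], list[int]]:
    direct = [i async for i in v1()]           # iterate straight away
    via_await = [i async for i in await v2()]  # await first, then iterate
    return direct, via_await


if __name__ == "__main__":
    print(inspect.isasyncgenfunction(v1))   # True
    print(inspect.iscoroutinefunction(v1))  # False
    print(inspect.iscoroutinefunction(v2))  # True
    print(asyncio.run(main()))              # ([0], [0])
```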

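    When you use async def without a yield, mypy treats the annotated return type as the type of the awaited result, so calling the function returns a Coroutine. If you also put a yield in, the function becomes an async generator, and calling it returns the annotated AsyncIterator directly: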

    reveal_type(foo)
    # -> typing.Coroutine[Any, Any, typing.AsyncIterator[builtins.int]]
    reveal_type(v1)
    # -> typing.AsyncIterator[builtins.int]
    reveal_type(v2)
    # -> typing.Coroutine[Any, Any, typing.AsyncIterator[builtins.int]]
    

    As you can see, because the abstract method contains no yield, it is inferred as returning Coroutine[..., AsyncIterator[int]]. In other words, it describes a function used like async for i in await v2(), which is v2's shape, not v1's.
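To make that concrete, here is a hedged sketch in which v2_spelled_out (a hypothetical name) is a plain def whose annotation writes out the Coroutine type that calling v2 produces. Using it takes two steps: await the coroutine, then iterate the async iterator it resolves to.

```python
import asyncio
from typing import Any, AsyncIterator, Coroutine


async def v1() -> AsyncIterator[int]:
    yield 0


async def v2() -> AsyncIterator[int]:
    # The annotation describes the *awaited* result, not the call result.
    return v1()


def v2_spelled_out() -> Coroutine[Any, Any, AsyncIterator[int]]:
    # Calling v2 without awaiting gives exactly this coroutine type.
    return v2()


async def main() -> list[int]:
    iterator = await v2_spelled_out()  # step 1: await the coroutine
    return [i async for i in iterator]  # step 2: iterate the result


if __name__ == "__main__":
    print(asyncio.run(main()))  # [0]
```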

    If we remove the async:

    @abc.abstractmethod
    def foo(self) -> AsyncIterator[int]:
        pass
    reveal_type(foo)
    # -> typing.AsyncIterator[builtins.int]
    

    The return type is now AsyncIterator, so the abstract method is compatible with v1 rather than v2. In other words, it describes a function used like async for i in v1().
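With the non-async abstract method, an async generator override type-checks, and callers iterate the result without any await. A minimal sketch, with the class names Base and Impl and the helper consume() chosen only for illustration:

```python
import abc
import asyncio
from typing import AsyncIterator


class Base(abc.ABC):
    @abc.abstractmethod
    def foo(self) -> AsyncIterator[int]:
        # Plain `def`: callers get an AsyncIterator without awaiting.
        ...


class Impl(Base):
    async def foo(self) -> AsyncIterator[int]:
        # Async generator: calling it also returns an AsyncIterator,
        # so this override matches the abstract signature.
        for i in range(3):
            yield i


async def consume(obj: Base) -> list[int]:
    # Used like `async for i in v1()`: no `await` before the call.
    return [i async for i in obj.foo()]


if __name__ == "__main__":
    print(asyncio.run(consume(Impl())))  # [0, 1, 2]
```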

    You can also see that this is fundamentally the same thing as v1:

    def v3(self) -> AsyncIterator[int]:
        return self.v1()
    

    Although the syntax differs, v3 and v1 are both functions that return an AsyncIterator when called, which should be obvious given that v3 literally returns the result of v1().
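Because both styles return an AsyncIterator when called, either one can implement the non-async abstract method. The sketch below uses hypothetical class names (GeneratorImpl for the v1 style, DelegatingImpl for the v3 style) and checks the result at runtime with collections.abc.AsyncIterator:

```python
import abc
import asyncio
from collections.abc import AsyncIterator


class Base(abc.ABC):
    @abc.abstractmethod
    def foo(self) -> AsyncIterator[int]:
        ...


class GeneratorImpl(Base):
    async def foo(self) -> AsyncIterator[int]:
        # v1 style: the method itself is an async generator
        yield 0


class DelegatingImpl(Base):
    async def _gen(self) -> AsyncIterator[int]:
        yield 0

    def foo(self) -> AsyncIterator[int]:
        # v3 style: a plain method returning an async generator object
        return self._gen()


async def collect(obj: Base) -> list[int]:
    return [i async for i in obj.foo()]


if __name__ == "__main__":
    for impl in (GeneratorImpl(), DelegatingImpl()):
        print(isinstance(impl.foo(), AsyncIterator), asyncio.run(collect(impl)))
    # True [0]
    # True [0]
```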