
Make sure abstract method would be a coroutine when implemented


First of all, I'm aware it is the developer's responsibility to make sure you define your method as a coroutine when implementing a child class:

from abc import ABC, abstractmethod

class MyBase(ABC):
    @abstractmethod
    def the_coroutine(self):
        """
        I want this to be a coroutine
        """

class ImplBaseA(MyBase):
    async def the_coroutine(self):
        return "awaited"

class ImplBaseB(MyBase):
    def the_coroutine(self):
        # some condition that happens quite often
        if True:
            raise ValueError("From the first glance I was even awaited")
        return "not the coroutine"

But how can I prevent this issue from occurring in code like the following?

a = ImplBaseA()
b = ImplBaseB()

# inside some coroutine:
await a.the_coroutine()
# When reading the code, this line looks just as correct as the one above
await b.the_coroutine()
# and as long as the exception is raised, it even behaves correctly
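To make the failure mode concrete, here is a small runnable sketch (assuming Python 3.7+ for asyncio.run). ImplBaseC is a hypothetical variant of ImplBaseB that takes the non-raising path. The ValueError from ImplBaseB is raised by the plain call before await ever runs, so it looks like an ordinary awaited coroutine failing; only the non-raising path exposes the bug, as a TypeError at the await.

```python
import asyncio
from abc import ABC, abstractmethod


class MyBase(ABC):
    @abstractmethod
    def the_coroutine(self):
        """I want this to be a coroutine."""


class ImplBaseA(MyBase):
    async def the_coroutine(self):
        return "awaited"


class ImplBaseB(MyBase):
    def the_coroutine(self):
        # Plain function: the exception comes from the call itself.
        raise ValueError("From the first glance I was even awaited")


class ImplBaseC(MyBase):
    # Hypothetical: ImplBaseB's non-raising path, returning a plain str.
    def the_coroutine(self):
        return "not the coroutine"


async def main():
    results = [await ImplBaseA().the_coroutine()]
    try:
        await ImplBaseB().the_coroutine()
    except ValueError:
        # Raised before `await` runs, so nothing looks wrong yet.
        results.append("ValueError")
    try:
        await ImplBaseC().the_coroutine()
    except TypeError:
        # Only here does awaiting a non-awaitable str fail.
        results.append("TypeError")
    return results


results = asyncio.run(main())
print(results)  # ['awaited', 'ValueError', 'TypeError']
```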

Should I use mypy or some similar tool? What's the Pythonic way of making sure an implementation is a coroutine function only (or, for other methods, a regular function only)?


Solution

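  • You can augment the checking an ABC base already does with an __init_subclass__ method that verifies that every overridden method keeps its "sync/async-ness": a coroutine function in a base class must stay a coroutine function in the child class, and a regular function must stay a regular function.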

    That would look something along these lines:

    from abc import ABC
    from inspect import iscoroutinefunction

    class AsynchSentinelABC(ABC):
        def __init_subclass__(cls, *args, **kw):
            super().__init_subclass__(*args, **kw)
            for meth_name in cls.__dict__:
                # Use getattr() on both sides so classmethods/staticmethods
                # are unwrapped the same way before the check.
                coro_status = iscoroutinefunction(getattr(cls, meth_name))
                # Compare against the nearest ancestor defining the same name.
                for parent in cls.__mro__[1:]:
                    if meth_name in parent.__dict__:
                        if coro_status != iscoroutinefunction(getattr(parent, meth_name)):
                            raise TypeError(f"Method {meth_name} is not the same type in child class {cls.__qualname__}")
                        break

    class MyBase(AsynchSentinelABC):
        ...
        
    

    Of course, this imposes an artificial restriction: an overriding method can't be a synchronous function that returns an awaitable, even though awaiting its result would work.

    Detecting this with static type checking and a tool like mypy would also work. I think the most straightforward way would be to create a typing.Protocol spelling out the interface you currently define with the abstract base class. (Conversely, if you rely only on static checking, you may not need the abstract base class for anything else at all.)
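To see the __init_subclass__ check from the answer in action, here is a self-contained sketch (the same AsynchSentinelABC as above, with getattr() used on both sides of the comparison). MyBase declares the_coroutine as an abstract coroutine; the question's ImplBaseB is then rejected when its class statement runs, not when it is first awaited. ImplBaseD is a hypothetical class illustrating the restriction mentioned above: it returns an awaitable from asyncio.sleep, so awaiting its result would work, yet it is rejected because it is not a coroutine function. The define_* helpers are hypothetical too; they only exist so each failing class definition can be caught.

```python
import asyncio
from abc import ABC, abstractmethod
from inspect import iscoroutinefunction


class AsynchSentinelABC(ABC):
    def __init_subclass__(cls, *args, **kw):
        super().__init_subclass__(*args, **kw)
        for meth_name in cls.__dict__:
            coro_status = iscoroutinefunction(getattr(cls, meth_name))
            # Compare against the nearest ancestor defining the same name.
            for parent in cls.__mro__[1:]:
                if meth_name in parent.__dict__:
                    if coro_status != iscoroutinefunction(getattr(parent, meth_name)):
                        raise TypeError(
                            f"Method {meth_name} is not the same type "
                            f"in child class {cls.__qualname__}"
                        )
                    break


class MyBase(AsynchSentinelABC):
    @abstractmethod
    async def the_coroutine(self):
        """Implementations must be coroutine functions too."""


class ImplBaseA(MyBase):
    async def the_coroutine(self):
        return "awaited"


def define_sync_impl():
    # The question's ImplBaseB: a plain function overriding an async one.
    class ImplBaseB(MyBase):
        def the_coroutine(self):
            return "not the coroutine"
    return ImplBaseB


def define_awaitable_returning_impl():
    # Hypothetical: returns an awaitable, but is still a plain function,
    # so the check rejects it (the artificial restriction).
    class ImplBaseD(MyBase):
        def the_coroutine(self):
            return asyncio.sleep(0, result="awaited anyway")
    return ImplBaseD


errors = []
for definer in (define_sync_impl, define_awaitable_returning_impl):
    try:
        definer()
    except TypeError as exc:
        # Raised while the class statement executes.
        errors.append(str(exc))

result = asyncio.run(ImplBaseA().the_coroutine())
print(result)       # awaited
print(len(errors))  # 2
```

Because the error surfaces at import/definition time, a broken implementation fails as soon as its module is loaded rather than on some rarely taken code path.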
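For the static route, a minimal sketch of the typing.Protocol idea might look like this (SupportsTheCoroutine and use are hypothetical names). The protocol declares the_coroutine as an async method returning str; ImplBaseA satisfies it structurally, without inheriting from anything. Only the valid call runs here; the commented-out call is the kind of thing a checker such as mypy should report as an incompatible argument type, since ImplBaseB.the_coroutine returns str rather than a coroutine. The same applies to the ABC version as far as I know: if the abstract method is annotated as async def the_coroutine(self) -> str, mypy flags ImplBaseB's override as incompatible, provided the methods carry annotations (mypy does not type check unannotated functions by default).

```python
import asyncio
from typing import Protocol


class SupportsTheCoroutine(Protocol):
    # Structural interface: any object whose `the_coroutine` is an async
    # method returning str matches it; no inheritance is required.
    async def the_coroutine(self) -> str: ...


class ImplBaseA:
    async def the_coroutine(self) -> str:
        return "awaited"


class ImplBaseB:
    # Plain method: does not match the protocol's async signature.
    def the_coroutine(self) -> str:
        return "not the coroutine"


async def use(obj: SupportsTheCoroutine) -> str:
    return await obj.the_coroutine()


result = asyncio.run(use(ImplBaseA()))
print(result)  # awaited

# A static checker should reject this call before it is ever run:
# asyncio.run(use(ImplBaseB()))
```

Note that the protocol is enforced only by the type checker; at runtime nothing stops ImplBaseB from being passed in, which is why the two approaches can complement each other.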