
Type hints for class decorator


I have a class decorator which removes one method and adds another to a class.

How could I provide type hints for that? I've obviously tried to research this myself, to no avail.

Most people claim this requires an intersection type. Is there any recommended solution? Something I'm missing?

Example code:

from typing import Callable, Protocol

class MyProtocol(Protocol):
    def do_check(self) -> bool:
        raise NotImplementedError

def decorator(clazz: type[MyProtocol]) -> ???:
    do_check: Callable[[MyProtocol], bool] = getattr(clazz, "do_check")

    def do_assert(self: MyProtocol) -> None:
        assert do_check(self)

    delattr(clazz, "do_check")
    setattr(clazz, "do_assert", do_assert)
    
    return clazz

@decorator
class MyClass(MyProtocol):
    def do_check(self) -> bool:
        return False

mc = MyClass()
mc.do_check()   # the type checker thinks this exists, but it was deleted
mc.do_assert()  # the type checker doesn't know about this, but it exists at runtime

I guess what I'm looking for is the correct return type for decorator.


Solution

  • There is no type annotation that can do what you want. Even with intersection types, there would be no way to express deleting an attribute; the best you could do is intersect with a type that overrides do_check with some kind of unusable descriptor.
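Python has no intersection type to combine it with, but the "unusable descriptor" half could look roughly like the hypothetical sketch below (`Unusable` and `Example` are illustrative names, not part of the answer). Its `__get__` is annotated as returning `NoReturn`, so a type checker such as mypy should infer that type for the attribute, and at runtime any access raises `AttributeError`, which also makes `hasattr` return `False`.

```python
from __future__ import annotations

from typing import NoReturn


class Unusable:
    """Hypothetical descriptor that stands in for a removed method."""

    def __init__(self, name: str = "<removed>") -> None:
        self.name = name

    def __set_name__(self, owner: type, name: str) -> None:
        # Called automatically when the descriptor is assigned in a class body.
        self.name = name

    def __get__(self, obj: object, objtype: type | None = None) -> NoReturn:
        # Both instance access and class access end up here and fail.
        raise AttributeError(f"{self.name!r} has been removed")


class Example:
    do_check = Unusable()


print(hasattr(Example(), "do_check"))  # False
```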

    What you're asking for can instead be done with a mypy plugin. After a basic implementation, the result looks like this:

    from package.decorator_module import MyProtocol, decorator
    
    @decorator
    class MyClass(MyProtocol):
        def do_check(self) -> bool:
            return False
    
    mc = MyClass()  # mypy: Cannot instantiate abstract class "MyClass" with abstract attribute "do_check" [abstract]
    mc.do_check()   # raises `NotImplementedError` at runtime
    mc.do_assert()  # mypy: OK
    

    Note that mc.do_check still exists, but mypy sees it as an abstract method once the plugin has removed the override. This matches the runtime behaviour: deleting MyClass.do_check with delattr merely exposes the parent's MyProtocol.do_check, and mypy treats a protocol method with a trivial body (such as `raise NotImplementedError`) as implicitly abstract, so it won't let you instantiate a subclass that doesn't override it.
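The runtime half of this is easy to check without mypy. The sketch below repeats the question's classes and applies a bare `delattr` (no decorator, an assumption made to isolate the effect): once the override is gone, attribute lookup falls through to `MyProtocol.do_check`, and the class can still be instantiated, because the "abstract" treatment exists only in the type checker.

```python
from typing import Protocol


class MyProtocol(Protocol):
    def do_check(self) -> bool:
        raise NotImplementedError


class MyClass(MyProtocol):
    def do_check(self) -> bool:
        return False


delattr(MyClass, "do_check")  # what the decorator does to the class

mc = MyClass()  # no runtime error: only mypy treats do_check as abstract
assert MyClass.do_check is MyProtocol.do_check  # lookup now finds the parent's method
try:
    mc.do_check()
except NotImplementedError:
    print("do_check fell through to MyProtocol")
```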


    Here's a basic implementation. Use the following directory structure:

    project/
      mypy.ini
      mypy_plugin.py
      test.py
      package/
        __init__.py
        decorator_module.py
    

    Contents of mypy.ini

    [mypy]
    plugins = mypy_plugin.py
    

    Contents of mypy_plugin.py

    from __future__ import annotations
    
    import typing_extensions as t
    
    import mypy.plugin
    import mypy.plugins.common
    import mypy.types
    
    if t.TYPE_CHECKING:
        import collections.abc as cx
        import mypy.nodes
    
    def plugin(version: str) -> type[DecoratorPlugin]:
        return DecoratorPlugin
    
    class DecoratorPlugin(mypy.plugin.Plugin):
    
        # See https://mypy.readthedocs.io/en/stable/extending_mypy.html#current-list-of-plugin-hooks
        # Since this is a class definition modification with a class decorator
        # and the class body should have been semantically analysed by the time
        # the class definition is to be manipulated, we choose
        # `get_class_decorator_hook_2`
        def get_class_decorator_hook_2(
            self, fullname: str
        ) -> cx.Callable[[mypy.plugin.ClassDefContext], bool] | None:
            if fullname == "package.decorator_module.decorator":
                return class_decorator_hook
            return None
    
    def class_decorator_hook(ctx: mypy.plugin.ClassDefContext) -> bool:
        mypy.plugins.common.add_method_to_class(
            ctx.api,
            cls=ctx.cls,
            name="do_assert",
            args=[],  # no parameters besides the implicit `self`
            return_type=mypy.types.NoneType(),
            self_type=ctx.api.named_type(ctx.cls.fullname),
        )
        # Remove `do_check` from the class's symbol table so mypy falls back to the inherited one
        ctx.cls.info.names.pop("do_check", None)
        return True  # True = the class is fully defined; False would request another round of semantic analysis
    

    Contents of test.py

    from package.decorator_module import MyProtocol, decorator
    
    @decorator
    class MyClass(MyProtocol):
        def do_check(self) -> bool:
            return False
    
    mc = MyClass()  # mypy: Cannot instantiate abstract class "MyClass" with abstract attribute "do_check" [abstract]
    mc.do_check()   # raises `NotImplementedError` at runtime
    mc.do_assert()  # mypy: OK
    

    Contents of package/decorator_module.py

    from __future__ import annotations
    
    import typing_extensions as t
    
    if t.TYPE_CHECKING:
        import collections.abc as cx
        _T = t.TypeVar("_T")
    
    class MyProtocol(t.Protocol):
        def do_check(self) -> bool:
            raise NotImplementedError
    
    # Mypy doesn't let a class decorator's return type change the decorated class,
    # so these annotations only matter for checking the decorator body itself; the
    # plugin does the real work when it sees `@package.decorator_module.decorator`.
    def decorator(clazz: type[_T]) -> type[_T]:
    
        do_check: cx.Callable[[_T], bool] = getattr(clazz, "do_check")
    
        def do_assert(self: _T) -> None:
            assert do_check(self)
    
        delattr(clazz, "do_check")
        setattr(clazz, "do_assert", do_assert)
    
        return clazz