Search code examples
Tags: python, decorator, mypy, python-typing

How to type-annotate an optionally parametrized decorator that uses a third-party decorator inside


I ran into a problem when running mypy on my project. I use the backoff package to retry some functions/methods. Then I realised that most of the options were just repeated, so I created a per-project decorator with all the common values pre-filled for the backoff decorator. But I don't know how to annotate such a "decorator inside a decorator". On top of that, it has to work for the whole sync/async × function/method matrix. Here is code that reproduces my pain:

import asyncio
from collections.abc import Awaitable, Callable
from functools import wraps
from typing import Any, Literal, ParamSpec, TypeVar

# Generic type variables shared by the aliases and functions below.
T = TypeVar("T")
P = ParamSpec("P")
# A callable taking parameters P and returning either T or an awaitable of T;
# the intent is to describe both sync and async functions with one alias.
# NOTE(review): T is unbounded, so T alone can already be an Awaitable.
AnyCallable = Callable[P, T | Awaitable[T]]
# A decorator mapping an AnyCallable to an AnyCallable with the same P and T.
Decorator = Callable[[AnyCallable[P, T]], AnyCallable[P, T]]


def third_party_decorator(
        a: int | None = None,
        b: str | None = None,
        c: Literal[None] = None,
        d: bool | None = None,
        e: str | None = None,
    ) -> Decorator[P, T]:
    """Simulate a third-party decorator factory (stand-in for ``backoff``).

    The returned decorator wraps ``f`` so that every call first prints the
    wrapped function and the factory options, then forwards all arguments
    to ``f`` and returns its result unchanged.
    """
    def decorate(f: AnyCallable[P, T]) -> AnyCallable[P, T]:
        def forward(*args: P.args, **kwargs: P.kwargs) -> T | Awaitable[T]:
            # The ``=`` specifier prints the variable names themselves, so
            # ``f`` and ``a``..``e`` must keep exactly these names.
            print(f"third_party_decorator {f = } {a = } {b = } {c = } {d = } {e = }")
            return f(*args, **kwargs)

        # Copy __name__, __doc__, __wrapped__, ... from f onto the wrapper.
        return wraps(f)(forward)

    return decorate


def parametrized_decorator(f: AnyCallable[P, T] | None = None, **kwargs: Any) -> Decorator[P, T] | AnyCallable[P, T]:
    """Apply ``third_party_decorator`` with project-wide default options.

    Usable bare (``@parametrized_decorator``) or with keyword overrides
    (``@parametrized_decorator(b="B", e="e")``). When ``f`` is given it is
    decorated immediately; otherwise a decorator is returned. Keyword
    arguments override the defaults ``a=1, b="b", c=None, d=True``.
    """
    def apply_defaults(f: AnyCallable[P, T]) -> AnyCallable[P, T]:
        defaults = {"a": 1, "b": "b", "c": None, "d": True}
        defaults.update(kwargs)
        # ``{f = }`` / ``{defaults = }`` print the variable names, so keep them.
        print(f"parametrized_decorator {f = } {defaults = }")
        third_party: Decorator[P, T] = third_party_decorator(**defaults)
        return third_party(f)

    return apply_defaults if f is None else apply_defaults(f)


@parametrized_decorator
def sync_straight_function(x: int = 0) -> None:
    """Sync function decorated without arguments; prints its argument."""
    message = f"sync_straight_function {x = }"
    print(message)


@parametrized_decorator(b="B", e="e")
def sync_parametrized_function(x: str = "abc", y: bool = False) -> None:
    """Sync function decorated with keyword overrides; prints its arguments."""
    message = f"sync_parametrized_function {x = } {y = }"
    print(message)


@parametrized_decorator
async def async_straight_function(x: int = 0) -> None:
    """Async function decorated without arguments; prints its argument."""
    # The label names this coroutine function so its output is distinguishable
    # from the sync variant's.
    print(f"async_straight_function {x = }")


@parametrized_decorator(b="B", e="e")
async def async_parametrized_function(x: str = "abc", y: bool = False) -> None:
    """Async function decorated with keyword overrides; prints its arguments."""
    # The label names this coroutine function so its output is distinguishable
    # from the sync variant's.
    print(f"async_parametrized_function {x = } {y = }")


class Foo:
    """Demo class applying the decorator to sync/async, bare/parametrized methods."""

    @parametrized_decorator
    def sync_straight_method(self, x: int = 0) -> None:
        """Sync method decorated without arguments; prints its argument."""
        print(f"sync_straight_method {x = }")

    @parametrized_decorator(b="B", e="e")
    def sync_parametrized_method(self, x: str = "abc", y: bool = False) -> None:
        """Sync method decorated with keyword overrides; prints its arguments."""
        # Same ``{name = }`` output format as every other demo function/method.
        print(f"sync_parametrized_method {x = } {y = }")

    @parametrized_decorator
    async def async_straight_method(self, x: int = 0) -> None:
        """Async method decorated without arguments; prints its argument."""
        print(f"async_straight_method {x = }")

    @parametrized_decorator(b="B", e="e")
    async def async_parametrized_method(self, x: str = "abc", y: bool = False) -> None:
        """Async method decorated with keyword overrides; prints its arguments."""
        print(f"async_parametrized_method {x = } {y = }")


def main_sync_functions() -> None:
    """Call both decorated sync module-level functions.

    Each function is called once with its defaults and once with explicit
    positional arguments.
    """
    sync_straight_function()
    sync_straight_function(1)

    sync_parametrized_function()
    sync_parametrized_function("xyz", True)


async def main_async_functions() -> None:
    """Await both decorated async module-level functions.

    Each coroutine function is called once with its defaults and once with
    explicit positional arguments, and the result is awaited.
    """
    await async_straight_function()
    await async_straight_function(1)

    await async_parametrized_function()
    await async_parametrized_function("xyz", True)


def main_sync_methods() -> None:
    """Create a Foo and call its decorated sync methods.

    Each method is called once with its defaults and once with explicit
    positional arguments.
    """
    foo = Foo()
    foo.sync_straight_method()
    foo.sync_straight_method(1)

    foo.sync_parametrized_method()
    foo.sync_parametrized_method("xyz", True)


async def main_async_methods() -> None:
    """Create a Foo and await its decorated async methods.

    Each method is called once with its defaults and once with explicit
    positional arguments, and the result is awaited.
    """
    foo = Foo()
    await foo.async_straight_method()
    await foo.async_straight_method(1)

    await foo.async_parametrized_method()
    await foo.async_parametrized_method("xyz", True)


if __name__ == "__main__":
    # (header, runner) pairs executed in order; async entry points are
    # driven to completion with asyncio.run.
    sections: tuple[tuple[str, Callable[[], None]], ...] = (
        ("\nSYNC FUNCTIONS:", main_sync_functions),
        ("\nASYNC FUNCTIONS:", lambda: asyncio.run(main_async_functions())),
        ("\nSYNC METHODS:", main_sync_methods),
        ("\nASYNC METHODS:", lambda: asyncio.run(main_async_methods())),
    )
    for header, run in sections:
        print(header)
        run()

mypy reports 44 errors:

parametrized-decorator-typing.py:33: error: Argument 1 to "third_party_decorator" has incompatible type "**dict[str, object]"; expected "int | None"  [arg-type]
            decorator: Decorator[P, T] = third_party_decorator(**defaults)
                                                                 ^~~~~~~~
parametrized-decorator-typing.py:33: error: Argument 1 to "third_party_decorator" has incompatible type "**dict[str, object]"; expected "str | None"  [arg-type]
            decorator: Decorator[P, T] = third_party_decorator(**defaults)
                                                                 ^~~~~~~~
parametrized-decorator-typing.py:33: error: Argument 1 to "third_party_decorator" has incompatible type "**dict[str, object]"; expected "None"  [arg-type]
            decorator: Decorator[P, T] = third_party_decorator(**defaults)
                                                                 ^~~~~~~~
parametrized-decorator-typing.py:33: error: Argument 1 to "third_party_decorator" has incompatible type "**dict[str, object]"; expected "bool | None"  [arg-type]
            decorator: Decorator[P, T] = third_party_decorator(**defaults)
                                                                 ^~~~~~~~
parametrized-decorator-typing.py:48: error: Argument 1 has incompatible type "Callable[[str, bool], None]"; expected
"Callable[[VarArg(<nothing>), KwArg(<nothing>)], <nothing> | Awaitable[<nothing>]]"  [arg-type]
    @parametrized_decorator(b="B", e="e")
     ^
parametrized-decorator-typing.py:48: error: Argument 1 has incompatible type "Callable[[str, bool], None]"; expected <nothing>  [arg-type]
    @parametrized_decorator(b="B", e="e")
     ^
parametrized-decorator-typing.py:53: error: Argument 1 to "parametrized_decorator" has incompatible type "Callable[[int], Coroutine[Any, Any, None]]"; expected
"Callable[[int], <nothing> | Awaitable[<nothing>]] | None"  [arg-type]
    @parametrized_decorator
     ^
parametrized-decorator-typing.py:58: error: Argument 1 has incompatible type "Callable[[str, bool], Coroutine[Any, Any, None]]"; expected
"Callable[[VarArg(<nothing>), KwArg(<nothing>)], <nothing> | Awaitable[<nothing>]]"  [arg-type]
    @parametrized_decorator(b="B", e="e")
     ^
parametrized-decorator-typing.py:58: error: Argument 1 has incompatible type "Callable[[str, bool], Coroutine[Any, Any, None]]"; expected <nothing>  [arg-type]
    @parametrized_decorator(b="B", e="e")
     ^
parametrized-decorator-typing.py:68: error: Argument 1 has incompatible type "Callable[[Foo, str, bool], None]"; expected
"Callable[[VarArg(<nothing>), KwArg(<nothing>)], <nothing> | Awaitable[<nothing>]]"  [arg-type]
        @parametrized_decorator(b="B", e="e")
         ^
parametrized-decorator-typing.py:68: error: Argument 1 has incompatible type "Callable[[Foo, str, bool], None]"; expected <nothing>  [arg-type]
        @parametrized_decorator(b="B", e="e")
         ^
parametrized-decorator-typing.py:72: error: Argument 1 to "parametrized_decorator" has incompatible type "Callable[[Foo, int], Coroutine[Any, Any, None]]"; expected
"Callable[[Foo, int], <nothing> | Awaitable[<nothing>]] | None"  [arg-type]
        @parametrized_decorator
         ^
parametrized-decorator-typing.py:76: error: Argument 1 has incompatible type "Callable[[Foo, str, bool], Coroutine[Any, Any, None]]"; expected
"Callable[[VarArg(<nothing>), KwArg(<nothing>)], <nothing> | Awaitable[<nothing>]]"  [arg-type]
        @parametrized_decorator(b="B", e="e")
         ^
parametrized-decorator-typing.py:76: error: Argument 1 has incompatible type "Callable[[Foo, str, bool], Coroutine[Any, Any, None]]"; expected <nothing>  [arg-type]
        @parametrized_decorator(b="B", e="e")
         ^
parametrized-decorator-typing.py:82: error: Too few arguments  [call-arg]
        sync_straight_function()
        ^~~~~~~~~~~~~~~~~~~~~~~~
parametrized-decorator-typing.py:83: error: Argument 1 has incompatible type "int"; expected "Callable[[int], Awaitable[None] | None]"  [arg-type]
        sync_straight_function(1)
                               ^
parametrized-decorator-typing.py:85: error: "Awaitable[<nothing>]" not callable  [operator]
        sync_parametrized_function()
        ^~~~~~~~~~~~~~~~~~~~~~~~~~~~
parametrized-decorator-typing.py:86: error: "Awaitable[<nothing>]" not callable  [operator]
        sync_parametrized_function("xyz", True)
        ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
parametrized-decorator-typing.py:86: error: Argument 1 has incompatible type "str"; expected <nothing>  [arg-type]
        sync_parametrized_function("xyz", True)
                                   ^~~~~
parametrized-decorator-typing.py:86: error: Argument 2 has incompatible type "bool"; expected <nothing>  [arg-type]
        sync_parametrized_function("xyz", True)
                                          ^~~~
parametrized-decorator-typing.py:90: error: Incompatible types in "await" (actual type "Callable[[int], <nothing> | Awaitable[<nothing>]] | Awaitable[<nothing>]",
expected type "Awaitable[Any]")  [misc]
        await async_straight_function()
        ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
parametrized-decorator-typing.py:90: error: Too few arguments  [call-arg]
        await async_straight_function()
              ^~~~~~~~~~~~~~~~~~~~~~~~~
parametrized-decorator-typing.py:91: error: Incompatible types in "await" (actual type "Callable[[int], <nothing> | Awaitable[<nothing>]] | Awaitable[<nothing>]",
expected type "Awaitable[Any]")  [misc]
        await async_straight_function(1)
        ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
parametrized-decorator-typing.py:91: error: Argument 1 has incompatible type "int"; expected "Callable[[int], <nothing> | Awaitable[<nothing>]]"  [arg-type]
        await async_straight_function(1)
                                      ^
parametrized-decorator-typing.py:93: error: "Awaitable[<nothing>]" not callable  [operator]
        await async_parametrized_function()
              ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~
parametrized-decorator-typing.py:94: error: "Awaitable[<nothing>]" not callable  [operator]
        await async_parametrized_function("xyz", True)
              ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
parametrized-decorator-typing.py:94: error: Argument 1 has incompatible type "str"; expected <nothing>  [arg-type]
        await async_parametrized_function("xyz", True)
                                          ^~~~~
parametrized-decorator-typing.py:94: error: Argument 2 has incompatible type "bool"; expected <nothing>  [arg-type]
        await async_parametrized_function("xyz", True)
                                                 ^~~~
parametrized-decorator-typing.py:99: error: Too few arguments  [call-arg]
        foo.sync_straight_method()
        ^~~~~~~~~~~~~~~~~~~~~~~~~~
parametrized-decorator-typing.py:100: error: Argument 1 has incompatible type "int"; expected "Callable[[Foo, int], Awaitable[None] | None]"  [arg-type]
        foo.sync_straight_method(1)
                                 ^
parametrized-decorator-typing.py:100: error: Argument 1 has incompatible type "int"; expected "Foo"  [arg-type]
        foo.sync_straight_method(1)
                                 ^
parametrized-decorator-typing.py:102: error: "Awaitable[<nothing>]" not callable  [operator]
        foo.sync_parametrized_method()
        ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
parametrized-decorator-typing.py:103: error: "Awaitable[<nothing>]" not callable  [operator]
        foo.sync_parametrized_method("xyz", True)
        ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
parametrized-decorator-typing.py:103: error: Argument 1 has incompatible type "str"; expected <nothing>  [arg-type]
        foo.sync_parametrized_method("xyz", True)
                                     ^~~~~
parametrized-decorator-typing.py:103: error: Argument 2 has incompatible type "bool"; expected <nothing>  [arg-type]
        foo.sync_parametrized_method("xyz", True)
                                            ^~~~
parametrized-decorator-typing.py:108: error: Incompatible types in "await" (actual type "Callable[[Foo, int], <nothing> | Awaitable[<nothing>]] | Awaitable[<nothing>]",
expected type "Awaitable[Any]")  [misc]
        await foo.async_straight_method()
        ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
parametrized-decorator-typing.py:108: error: Too few arguments  [call-arg]
        await foo.async_straight_method()
              ^~~~~~~~~~~~~~~~~~~~~~~~~~~
parametrized-decorator-typing.py:109: error: Incompatible types in "await" (actual type "Callable[[Foo, int], <nothing> | Awaitable[<nothing>]] | Awaitable[<nothing>]",
expected type "Awaitable[Any]")  [misc]
        await foo.async_straight_method(1)
        ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
parametrized-decorator-typing.py:109: error: Argument 1 has incompatible type "int"; expected "Callable[[Foo, int], <nothing> | Awaitable[<nothing>]]"  [arg-type]
        await foo.async_straight_method(1)
                                        ^
parametrized-decorator-typing.py:109: error: Argument 1 has incompatible type "int"; expected "Foo"  [arg-type]
        await foo.async_straight_method(1)
                                        ^
parametrized-decorator-typing.py:111: error: "Awaitable[<nothing>]" not callable  [operator]
        await foo.async_parametrized_method()
              ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
parametrized-decorator-typing.py:112: error: "Awaitable[<nothing>]" not callable  [operator]
        await foo.async_parametrized_method("xyz", True)
              ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
parametrized-decorator-typing.py:112: error: Argument 1 has incompatible type "str"; expected <nothing>  [arg-type]
        await foo.async_parametrized_method("xyz", True)
                                            ^~~~~
parametrized-decorator-typing.py:112: error: Argument 2 has incompatible type "bool"; expected <nothing>  [arg-type]
        await foo.async_parametrized_method("xyz", True)
                                                   ^~~~
Found 44 errors in 1 file (checked 1 source file)

Solution

  • The solution you proposed has one major flaw: your parametrized_decorator destroys the function signature. You still need a typevar to preserve it.

    First of all, your T is unbounded, so T itself can already be Awaitable[Something]. In your third_party_decorator, wrapper is annotated to return T | Awaitable[T], which is odd: it always returns exactly what f returns, and it can't become "more awaitable". So we can drop Awaitable here entirely, which greatly simplifies the typing.

    Then, your decorator can behave in two different ways: either be a 2nd-order deco (return a decorator) when func is not passed, or return a function otherwise. This can be expressed as an overload. Here's what I suggest (playground):

    import asyncio
    from collections.abc import Awaitable, Callable, Coroutine
    from functools import wraps
    from typing import Any, Literal, ParamSpec, TypeVar, overload
    
    # T/P appear in the public (overload) signatures.
    T = TypeVar("T")
    # _T1/_P1 are extra, independent variables used only in the
    # implementation's return annotation for the "returns a decorator" case.
    _T1 = TypeVar("_T1")
    P = ParamSpec("P")
    _P1 = ParamSpec("_P1")
    
    
    
    def third_party_decorator(
            a: int | None = None,
            b: str | None = None,
            c: Literal[None] = None,
            d: bool | None = None,
            e: str | None = None,
        ) -> Callable[[Callable[P, T]], Callable[P, T]]:
        """Stand-in for a third-party decorator factory.

        The returned decorator prints the wrapped function and the options on
        every call, then forwards the call to ``f`` and returns its result.
        """
        def decorate(f: Callable[P, T]) -> Callable[P, T]:
            def forward(*args: P.args, **kwargs: P.kwargs) -> T:
                print(f"third_party_decorator {f = } {a = } {b = } {c = } {d = } {e = }")
                return f(*args, **kwargs)

            # Copy the wrapped function's metadata onto the wrapper.
            return wraps(f)(forward)

        return decorate
    
    
    @overload
    def parametrized_decorator(f: None = ..., /, **kwargs: Any) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
    @overload
    def parametrized_decorator(f: Callable[P, T], /, **kwargs: Any) -> Callable[P, T]: ...
    def parametrized_decorator(f: Callable[P, T] | None = None, /, **kwargs: Any) -> Callable[[Callable[_P1, _T1]], Callable[_P1, _T1]] | Callable[P, T]:
        """Apply ``third_party_decorator`` with project-wide defaults.

        Works bare (``@parametrized_decorator``) or with keyword overrides
        (``@parametrized_decorator(b="B")``); the overloads tell mypy which
        of the two shapes is returned for each call form.
        """
        def decorator(f: Callable[P, T]) -> Callable[P, T]:
            # dict[str, Any] lets ``**defaults`` type-check against the
            # individually typed keyword parameters without a type-ignore.
            defaults: dict[str, Any] = {"a": 1, "b": "b", "c": None, "d": True}
            defaults.update(kwargs)
            print(f"parametrized_decorator {f = } {defaults = }")
            third_party = third_party_decorator(**defaults)
            return third_party(f)

        if f is None:
            # You can avoid type-ignore here, but this will require an insane if
            # clause with `decorator` body effectively repeated twice
            return decorator  # type: ignore[return-value]
        else:
            return decorator(f)
    

    Note the ugly _T1 and _P1 above? They are needed because Callable is itself generic, so its type variables bind too early (P and T must be the same on the input side and on both sides of the returned callable). Let's fix this with an "alternative Callable": a protocol that is not generic itself but has a generic __call__ (playground):

    import asyncio
    from collections.abc import Awaitable, Callable, Coroutine
    from functools import wraps
    from typing import Any, Literal, ParamSpec, TypeVar, overload, Protocol
    
    # Type variables for the generic __call__ of the Decorator protocol and
    # for the functions that use it.
    T = TypeVar("T")
    P = ParamSpec("P")
    
    
    class Decorator(Protocol):
        """Structural type for a signature-preserving decorator.

        The class itself is not generic; only ``__call__`` is, so P and T are
        bound freshly at each call instead of when the decorator is created.
        """
        def __call__(self, func: Callable[P, T], /) -> Callable[P, T]: ...
    
    
    def third_party_decorator(
            a: int | None = None,
            b: str | None = None,
            c: Literal[None] = None,
            d: bool | None = None,
            e: str | None = None,
        ) -> Decorator:
        """Stand-in for a third-party decorator factory.

        The returned decorator prints the wrapped function and the options on
        every call, then forwards the call to ``f`` and returns its result.
        """
        def decorate(f: Callable[P, T]) -> Callable[P, T]:
            def forward(*args: P.args, **kwargs: P.kwargs) -> T:
                print(f"third_party_decorator {f = } {a = } {b = } {c = } {d = } {e = }")
                return f(*args, **kwargs)

            # Copy the wrapped function's metadata onto the wrapper.
            return wraps(f)(forward)

        return decorate
    
    
    @overload
    def parametrized_decorator(f: None = ..., /, **kwargs: Any) -> Decorator: ...
    @overload
    def parametrized_decorator(f: Callable[P, T], /, **kwargs: Any) -> Callable[P, T]: ...
    def parametrized_decorator(f: Callable[P, T] | None = None, /, **kwargs: Any) -> Decorator | Callable[P, T]:
        """Apply ``third_party_decorator`` with project-wide defaults.

        Works bare (``@parametrized_decorator``) or with keyword overrides
        (``@parametrized_decorator(b="B")``); the overloads tell mypy which
        of the two shapes is returned for each call form.
        """
        def decorator(f: Callable[P, T]) -> Callable[P, T]:
            # dict[str, Any] lets ``**defaults`` type-check against the
            # individually typed keyword parameters without a type-ignore.
            defaults: dict[str, Any] = {"a": 1, "b": "b", "c": None, "d": True}
            defaults.update(kwargs)
            print(f"parametrized_decorator {f = } {defaults = }")
            third_party = third_party_decorator(**defaults)
            return third_party(f)

        if f is None:
            return decorator
        else:
            return decorator(f)
    

    And now you don't have any ignore comments and have proper return types.

    To illustrate what I mean by "destroys the signature": run reveal_type(sync_straight_function), or just try sync_straight_function('this', 'is', 'still', 'allowed', 'why?') + 1 with the implementation from your answer, and observe that mypy reports no errors.