Tags: python, mypy, python-typing

Typing an overloaded decorator wrapped in partial


I am trying to correctly type an overloaded decorator that is wrapped with functools.partial:

from functools import partial
from typing import Any, Callable, Optional, Union, overload


AnyCallable = Callable[..., Any]


class Wrapped:
    def __init__(self, func: AnyCallable, foo: str, bar: bool) -> None:
        pass


@overload
def create_wrapped(foo: str, func: AnyCallable) -> Wrapped:
    ...


@overload
def create_wrapped(foo: str, *, bar: bool = ...) -> Callable[[AnyCallable], Wrapped]:
    ...


def create_wrapped(
    foo: str,
    func: Optional[AnyCallable] = None,
    *,
    bar: bool = True,
) -> Union[Wrapped, Callable[[AnyCallable], Wrapped]]:
    def wrapper(func_: AnyCallable) -> Wrapped:
        return Wrapped(func_, foo, bar)

    if func is None:
        return wrapper
    return wrapper(func)


baz = partial(create_wrapped, "baz")


@baz
def func_1() -> None:
    pass


@baz(bar=False)
def func_2() -> None:
    pass

The code runs correctly, but mypy reports:

47: error: "Wrapped" not callable

This suggests that the overload information is lost when partial is applied: mypy appears to type baz(bar=False) as Wrapped. That call should match the second overload, because @baz(bar=False) is equivalent to @create_wrapped("baz", bar=False), which type-checks without any issue.
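A quick runtime check of that equivalence, as a sketch only. The attributes stored on Wrapped and the helper `target` are assumptions added here so the results can be compared; the overloads are left out because they do not affect runtime behaviour, so this snippet is meant to be run, not type-checked.

```python
from functools import partial
from typing import Any, Callable, Optional, Union

AnyCallable = Callable[..., Any]


class Wrapped:
    def __init__(self, func: AnyCallable, foo: str, bar: bool) -> None:
        # Assumption: keep the arguments so they can be inspected below.
        self.func = func
        self.foo = foo
        self.bar = bar


def create_wrapped(
    foo: str,
    func: Optional[AnyCallable] = None,
    *,
    bar: bool = True,
) -> Union[Wrapped, Callable[[AnyCallable], Wrapped]]:
    def wrapper(func_: AnyCallable) -> Wrapped:
        return Wrapped(func_, foo, bar)

    if func is None:
        return wrapper
    return wrapper(func)


def target() -> None:  # hypothetical function to decorate
    pass


baz = partial(create_wrapped, "baz")

# Call the decorator factory directly and through the partial.
direct = create_wrapped("baz", bar=False)(target)
via_partial = baz(bar=False)(target)

# Both routes build a Wrapped with the same foo and bar.
print((direct.foo, direct.bar) == (via_partial.foo, via_partial.bar))  # True
```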

I'm not sure how else to annotate this. I couldn't find any way to stop mypy from complaining, even when I gave up on proper types for the decorator: in that case I get an "Untyped decorator makes function untyped" error instead.


Solution

  • At the time of writing, mypy does not correctly infer the type of a partially applied function; see https://github.com/python/mypy/issues/1484.

    You can work around this by casting the result of the partial call to a Protocol whose overloaded __call__ spells out the signatures that remain once foo is bound:

    from functools import partial
    from typing import Any, Callable, Optional, Protocol, Union, overload, cast
    
    
    AnyCallable = Callable[..., Any]
    
    
    class Wrapped:
        def __init__(self, func: AnyCallable, foo: str, bar: bool) -> None:
            pass
    
    
    @overload
    def create_wrapped(foo: str, func: AnyCallable) -> Wrapped:
        ...
    
    
    @overload
    def create_wrapped(foo: str, *, bar: bool = ...) -> Callable[[AnyCallable], Wrapped]:
        ...
    
    
    def create_wrapped(
        foo: str,
        func: Optional[AnyCallable] = None,
        *,
        bar: bool = True,
    ) -> Union[Wrapped, Callable[[AnyCallable], Wrapped]]:
        def wrapper(func_: AnyCallable) -> Wrapped:
            return Wrapped(func_, foo, bar)
    
        if func is None:
            return wrapper
    
        return wrapper(func)
    
    
    class partial_create_wrapped(Protocol):
        @overload
        def __call__(self, *, bar: bool = ...) -> Callable[[AnyCallable], Wrapped]:
            ...
    
        @overload
        def __call__(self, func: AnyCallable) -> Wrapped:
            ...
    
    
    baz = cast(partial_create_wrapped, partial(create_wrapped, "baz"))
    
    
    @baz
    def func_1() -> None:
        pass
    
    
    @baz(bar=False)
    def func_2() -> None:
        pass
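If you would rather avoid cast, one possible alternative (a sketch, not something the answer above proposes) is to replace the partial with a small overloaded function that forwards to create_wrapped. The overloads are repeated by hand, but mypy then checks the forwarding calls itself instead of trusting a cast. The attributes stored on Wrapped are an assumption added only so the result can be inspected.

```python
from typing import Any, Callable, Optional, Union, overload

AnyCallable = Callable[..., Any]


class Wrapped:
    def __init__(self, func: AnyCallable, foo: str, bar: bool) -> None:
        # Assumption: keep the arguments so they can be inspected below.
        self.func = func
        self.foo = foo
        self.bar = bar


@overload
def create_wrapped(foo: str, func: AnyCallable) -> Wrapped:
    ...


@overload
def create_wrapped(foo: str, *, bar: bool = ...) -> Callable[[AnyCallable], Wrapped]:
    ...


def create_wrapped(
    foo: str,
    func: Optional[AnyCallable] = None,
    *,
    bar: bool = True,
) -> Union[Wrapped, Callable[[AnyCallable], Wrapped]]:
    def wrapper(func_: AnyCallable) -> Wrapped:
        return Wrapped(func_, foo, bar)

    if func is None:
        return wrapper
    return wrapper(func)


# Explicit replacement for partial(create_wrapped, "baz").
@overload
def baz(func: AnyCallable) -> Wrapped:
    ...


@overload
def baz(*, bar: bool = ...) -> Callable[[AnyCallable], Wrapped]:
    ...


def baz(
    func: Optional[AnyCallable] = None,
    *,
    bar: bool = True,
) -> Union[Wrapped, Callable[[AnyCallable], Wrapped]]:
    if func is None:
        # Used as @baz(...): return the decorator itself.
        return create_wrapped("baz", bar=bar)
    # Used as bare @baz: wrap the function immediately.
    return create_wrapped("baz", func)


@baz
def func_1() -> None:
    pass


@baz(bar=False)
def func_2() -> None:
    pass
```

The trade-off is duplication: every decorator created this way needs its own pair of overloads, whereas the Protocol in the cast approach can be reused for any partial that binds only foo.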