I am trying to get the typing right for an overloaded decorator that gets wrapped by `functools.partial`:
from functools import partial
from typing import Any, Callable, Optional, Union, overload
# Shorthand for "a callable with any signature and any return type".
AnyCallable = Callable[..., Any]
class Wrapped:
    """Result of applying ``create_wrapped`` to a callable.

    The constructor receives the decorated callable together with the
    ``foo`` and ``bar`` options but, as written, does not store any of them.
    """

    def __init__(self, func: AnyCallable, foo: str, bar: bool) -> None:
        """Accept the wrapped callable and its options (currently discarded)."""
@overload
def create_wrapped(foo: str, func: AnyCallable) -> Wrapped: ...


@overload
def create_wrapped(
    foo: str, *, bar: bool = ...
) -> Callable[[AnyCallable], Wrapped]: ...


def create_wrapped(
    foo: str,
    func: Optional[AnyCallable] = None,
    *,
    bar: bool = True,
) -> Union[Wrapped, Callable[[AnyCallable], Wrapped]]:
    """Wrap ``func`` immediately, or return a decorator that wraps it later.

    Two call shapes are supported (see the overloads above):

    * ``create_wrapped(foo, func)`` returns a ``Wrapped`` for ``func``.
    * ``create_wrapped(foo, bar=...)`` returns a one-argument decorator
      that turns a callable into a ``Wrapped``.

    Args:
        foo: Option forwarded to every ``Wrapped`` created by this call.
        func: Callable to wrap right away; ``None`` selects the
            decorator-factory form.
        bar: Keyword-only option forwarded to ``Wrapped`` (default ``True``).

    Returns:
        A ``Wrapped`` when ``func`` is given, otherwise a decorator that
        produces a ``Wrapped``.
    """

    def _make(target: AnyCallable) -> Wrapped:
        # Closes over foo and bar so the decorator form remembers them.
        return Wrapped(target, foo, bar)

    # A callable was passed directly: wrap it now.
    if func is not None:
        return _make(func)
    # Otherwise hand back the decorator for later application.
    return _make
# Pre-bind the positional `foo` argument to "baz". According to the question,
# mypy loses create_wrapped's overload information on this partial object.
baz = partial(create_wrapped, "baz")


# Bare decorator use: at runtime this is create_wrapped("baz", func_1).
@baz
def func_1() -> None:
    pass


# Decorator-factory use: at runtime this is
# create_wrapped("baz", bar=False)(func_2). The question reports mypy's
# '"Wrapped" not callable' error for this form.
@baz(bar=False)
def func_2() -> None:
    pass
The code runs correctly, but mypy reports

    47: error: "Wrapped" not callable

which indicates that the actual argument types are lost when `partial` is applied: `@baz(bar=False)` should match the second overload, since it is equivalent to `@create_wrapped("baz", bar=False)`, which type-checks without any issue.
I'm not sure how else I could annotate this. In fact, I couldn't come up with any way to make mypy stop complaining, even if I were fine with not having proper types for the decorator, since in that case I'd get an `Untyped decorator makes function untyped` error.
mypy does not currently infer the correct type of a partially applied function; see https://github.com/python/mypy/issues/1484.
You can work around this by casting the result of the `partial` call to a suitable `Protocol`:
from functools import partial
from typing import Any, Callable, Optional, Protocol, Union, overload, cast
# Shorthand for "a callable with any signature and any return type".
AnyCallable = Callable[..., Any]
class Wrapped:
    """Result of applying ``create_wrapped`` to a callable.

    The constructor receives the decorated callable together with the
    ``foo`` and ``bar`` options but, as written, does not store any of them.
    """

    def __init__(self, func: AnyCallable, foo: str, bar: bool) -> None:
        """Accept the wrapped callable and its options (currently discarded)."""
@overload
def create_wrapped(foo: str, func: AnyCallable) -> Wrapped: ...


@overload
def create_wrapped(
    foo: str, *, bar: bool = ...
) -> Callable[[AnyCallable], Wrapped]: ...


def create_wrapped(
    foo: str,
    func: Optional[AnyCallable] = None,
    *,
    bar: bool = True,
) -> Union[Wrapped, Callable[[AnyCallable], Wrapped]]:
    """Wrap ``func`` immediately, or return a decorator that wraps it later.

    Two call shapes are supported (see the overloads above):

    * ``create_wrapped(foo, func)`` returns a ``Wrapped`` for ``func``.
    * ``create_wrapped(foo, bar=...)`` returns a one-argument decorator
      that turns a callable into a ``Wrapped``.

    Args:
        foo: Option forwarded to every ``Wrapped`` created by this call.
        func: Callable to wrap right away; ``None`` selects the
            decorator-factory form.
        bar: Keyword-only option forwarded to ``Wrapped`` (default ``True``).

    Returns:
        A ``Wrapped`` when ``func`` is given, otherwise a decorator that
        produces a ``Wrapped``.
    """

    def _make(target: AnyCallable) -> Wrapped:
        # Closes over foo and bar so the decorator form remembers them.
        return Wrapped(target, foo, bar)

    # A callable was passed directly: wrap it now.
    if func is not None:
        return _make(func)
    # Otherwise hand back the decorator for later application.
    return _make
class partial_create_wrapped(Protocol):
    """Static call signature of ``partial(create_wrapped, <foo>)``.

    Restates the overloads of ``create_wrapped`` with the leading ``foo``
    argument already bound, so the type checker can resolve both decorator
    forms. In this snippet it is used only as a ``cast`` target.
    """

    # Decorator-factory form: baz(bar=...) -> decorator producing Wrapped.
    @overload
    def __call__(self, *, bar: bool = ...) -> Callable[[AnyCallable], Wrapped]: ...

    # Direct form: baz(func) -> Wrapped.
    @overload
    def __call__(self, func: AnyCallable) -> Wrapped: ...
# cast() only changes what the type checker believes; at runtime `baz` is
# still the plain functools.partial object.
baz = cast(partial_create_wrapped, partial(create_wrapped, "baz"))


# Resolves to the Protocol's `__call__(self, func)` overload -> Wrapped.
@baz
def func_1() -> None:
    pass


# Resolves to the Protocol's `__call__(self, *, bar=...)` overload, which
# returns a decorator; applying it to func_2 yields a Wrapped.
@baz(bar=False)
def func_2() -> None:
    pass