If I define my decorator like this:
from typing import Callable, Type, TypeVar

T = TypeVar('T', bound=Event)  # Event is my base event class

def register1(evtype: Type[T]) -> Callable[[Callable[[T], None]], Callable[[T], None]]:
    def decorator(handler):
        # register handler for event type
        return handler
    return decorator
The type checker reports the expected error when I apply it to a function with the wrong parameter type:
class A(Event):
    pass

class B(Event):
    pass

@register1(A)  # Argument of type "(ev: B) -> None" cannot be assigned to parameter of type "(A) -> None"
def handler1_1(ev: B):
    pass
However, stacking the decorator produces an error, even though the handler accepts both types:
@register1(A)  # Argument of type "(B) -> None" cannot be assigned to parameter of type "(A) -> None"
@register1(B)
def handler1_3(ev: A|B):
    pass
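The error mentions `(B) -> None` because stacked decorators are applied bottom-up. `@register1(A)` receives whatever `register1(B)` returned, and that return type is `Callable[[B], None]`, not the original `(ev: A|B) -> None` signature. The runtime order can be checked with a small tracing decorator. `trace` and `applied` are hypothetical names used only for this demonstration:

```python
from typing import Callable, List, TypeVar

F = TypeVar("F", bound=Callable[..., object])

applied: List[str] = []


def trace(label: str) -> Callable[[F], F]:
    # Decorator factory; the returned decorator records when it is applied
    def decorator(func: F) -> F:
        applied.append(label)
        return func
    return decorator


# The factory calls trace("outer") and trace("inner") run top to bottom,
# but the resulting decorators are applied bottom to top, i.e. this is
# equivalent to: handler = trace("outer")(trace("inner")(handler))
@trace("outer")
@trace("inner")
def handler(ev: object) -> None:
    pass
```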
What I want is for the stacked decorators to build up a Union of the allowed/required argument types.
I think ParamSpec is the way to solve this, but how can I use ParamSpec so that it does not overwrite the argument type, while still requiring that the argument type matches the type passed to the decorator?
Using ParamSpec, however, does not produce any type error, even where one is expected:
from typing import ParamSpec

P = ParamSpec("P")

def register2(evtype: Type[T]) -> Callable[[Callable[P, None]], Callable[P, None]]:
    def decorator(handler):
        # ...
        return handler
    return decorator

@register2(A)  # This should be an error
def handler2_1(ev: B):
    pass
If I add another TypeVar and use a Union, it works for double- and even triple-decorated functions, but not for single-decorated ones:
from typing import Union

T2 = TypeVar('T2')

def register3(evtype: Type[T]) -> Callable[[Callable[[Union[T, T2]], None]], Callable[[Union[T, T2]], None]]:
    def decorator(handler):
        # ...
        return handler
    return decorator

# Expected error:
@register3(A)  # Argument of type "(ev: B) -> None" cannot be assigned to parameter of type "(A | T2@register3) -> None"
def handler3_1(ev: B):
    pass

# Unexpected error:
@register3(A)  # Argument of type "(ev: A) -> None" cannot be assigned to parameter of type "(A | T2@register3) -> None"
def handler3_2(ev: A):
    pass

# Works fine
@register3(A)
@register3(B)
def handler3_3(ev: A|B):
    pass
While writing this question, I got closer and closer to a solution, which I will post as an answer. However, I'm interested in whether there are better ways to solve this.
Transferred from https://github.com/microsoft/pyright/discussions/7404:
from __future__ import annotations
from typing import Any, Callable, Protocol, TypeVar, overload

T_co = TypeVar("T_co", covariant=True)
T0 = TypeVar("T0")
T1 = TypeVar("T1")

class RegisterResult(Protocol[T_co]):
    @overload
    def __call__(self, handler: Callable[[T_co | T1], None]) -> Callable[[T_co | T1], None]: ...
    @overload
    def __call__(self, handler: Callable[[T_co], None]) -> Callable[[T_co], None]: ...

def register(evtype: type[T0]) -> RegisterResult[T0]:
    def decorator(handler: Any) -> Any:
        return handler
    return decorator

class A: ...
class B: ...
class C: ...

@register(A)
def handle_a(ev: A): ...

handle_a(A())

@register(A)
@register(B)
# ... should support any number of stacked @register calls
def handle_ab(ev: A|B): ...

handle_ab(A())
handle_ab(B())

# Expected errors because of wrong types:
@register(A)
def handle_b(ev: B): ...

handle_a(B())
handle_ab(C())
Note that the code above works with Pyright v1.1.353 (the latest at the time of writing), but it may stop working in a future version, depending on how Pyright evaluates compatibility between overloaded functions. As far as I checked, it also does NOT fully work with Mypy v1.9.0.