Tags: python, mypy, python-decorators, python-typing, pyright

Require decorated function to accept argument matching bound `TypeVar` without narrowing to that type


If I define my decorator like this:

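from typing import Callable, ParamSpec, Type, TypeVar, Union

class Event: ...  # base class for all events (definition omitted)

T = TypeVar('T', bound=Event)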

def register1(evtype: Type[T]) -> Callable[[Callable[[T], None]], Callable[[T], None]]:
    def decorator(handler):
        # register handler for event type
        return handler
    return decorator

I get a proper error if I use it on the wrong function:

class A(Event):
    pass

class B(Event):
    pass

@register1(A) # Argument of type "(ev: B) -> None" cannot be assigned to parameter of type "(A) -> None"
def handler1_1(ev: B):
    pass

However, it does not work if I apply the decorator multiple times:

@register1(A) # Argument of type "(B) -> None" cannot be assigned to parameter of type "(A) -> None"
@register1(B)
def handler1_3(ev: A|B):
    pass

What I want is for the stacked decorators to build up a Union of the allowed/required argument types.
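At runtime, stacking already behaves the way the Union idea suggests; only the static check is missing. Below is a minimal sketch of what the `# register handler for event type` placeholder might do, assuming a hypothetical module-level `HANDLERS` dict and a hypothetical `dispatch` helper (neither is in the original question). Pyright still reports the stacked-decorator error described above on `handler_ab`; the code only shows that each application records the same handler under one more event type.

```python
from typing import Callable, Dict, List, Type, TypeVar


class Event:
    """Stand-in for the question's Event base class."""


T = TypeVar("T", bound=Event)

# Hypothetical registry: event type -> handlers registered for it.
HANDLERS: Dict[Type[Event], List[Callable[..., None]]] = {}


def register1(evtype: Type[T]) -> Callable[[Callable[[T], None]], Callable[[T], None]]:
    def decorator(handler: Callable[[T], None]) -> Callable[[T], None]:
        # Record the handler under this event type, then return it unchanged
        # so that another @register1(...) can be stacked on top of it.
        HANDLERS.setdefault(evtype, []).append(handler)
        return handler
    return decorator


def dispatch(ev: Event) -> None:
    # Hypothetical helper: call every handler registered for the exact runtime type of ev.
    for handler in HANDLERS.get(type(ev), []):
        handler(ev)


class A(Event): pass
class B(Event): pass

seen: List[str] = []

# Decorators apply bottom-up: register1(B) runs first, then register1(A).
@register1(A)
@register1(B)
def handler_ab(ev: "A | B") -> None:
    seen.append(type(ev).__name__)


dispatch(A())
dispatch(B())
print(seen)  # ['A', 'B']
```

Note that this `dispatch` matches exact types only, so an instance of a subclass of `A` would not reach `handler_ab`.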

I think ParamSpec might be the way to solve it, but how can I use ParamSpec so that it does not overwrite the argument type, while still requiring that the argument type matches the type passed to the decorator?

Using ParamSpec, however, results in no type error at all:

P = ParamSpec("P")

def register2(evtype: Type[T]) -> Callable[[Callable[P, None]], Callable[P, None]]:
    def decorator(handler):
        # ...
        return handler
    return decorator

@register2(A) # This should be an error
def handler2_1(ev: B):
    pass

If I add another TypeVar and use a Union, it works for double- and even triple-decorated functions, but not for single-decorated ones:

T2 = TypeVar('T2')

def register3(evtype: Type[T]) -> Callable[[Callable[[Union[T,T2]], None]], Callable[[Union[T,T2]], None]]:
    def decorator(handler):
        # ...
        return handler
    return decorator

# Expected error:
@register3(A) # Argument of type "(ev: B) -> None" cannot be assigned to parameter of type "(A | T2@register3) -> None"
def handler3_1(ev: B):
    pass

# Spurious error (this should type-check):
@register3(A) # Argument of type "(ev: A) -> None" cannot be assigned to parameter of type "(A | T2@register3) -> None"
def handler3_2(ev: A):
    pass

# Works fine
@register3(A)
@register3(B)
def handler3_3(ev: A|B):
    pass

While writing this question, I got closer and closer to a solution, which I will post as an answer.

However, I'm interested in whether there are better ways to solve this.


Solution

  • Taken from https://github.com/microsoft/pyright/discussions/7404:

    from __future__ import annotations
    
    from typing import Any, Callable, Protocol, TypeVar, overload
    
    T_co = TypeVar("T_co", covariant=True)
    
    T0 = TypeVar("T0")
    T1 = TypeVar('T1')
    
    class RegisterResult(Protocol[T_co]):
        @overload
        def __call__(self, handler: Callable[[T_co | T1], None]) -> Callable[[T_co | T1], None]: ...
    
        @overload
        def __call__(self, handler: Callable[[T_co], None]) -> Callable[[T_co], None]: ...
    
    def register(evtype: type[T0]) -> RegisterResult[T0]:
        def decorator(handler: Any) -> Any:
            return handler
        
        return decorator
    
    class A: ...
    class B: ...
    class C: ...
    
    @register(A)
    def handle_a(ev: A): ...
    
    handle_a(A())
    
    @register(A)
    @register(B)
    # ... should support any number of stacked @register calls
    def handle_ab(ev: A|B): ... 
    
    handle_ab(A())
    handle_ab(B())
    
    # Expected errors because of wrong types:
    @register(A)
    def handle_b(ev: B): ... 
    handle_a(B())
    handle_ab(C())
    

    Note that the code above works with Pyright v1.1.353 (the latest release at the time of writing) and may stop working in a future version, depending on how Pyright checks compatibility between overloads. As far as I checked, it does NOT fully work with Mypy v1.9.0 (also the latest release at the time).