python generics mypy

Typing a function when a decorator changes a generic return type


This is similar to "Typing function when decorator change return type", but this time with a generic return type:

from typing import Any, Callable, Generic, TypeVar, cast

T = TypeVar('T')

class Container(Generic[T]):
    def __init__(self, item: T):
        self.item: T = item

def my_decorator(f: Callable[..., T]) -> Callable[..., Container[T]]:

    def wrapper(*args: Any, **kwargs: Any) -> Container[T]:
        return Container(f(*args, **kwargs))

    return cast(Callable[..., Container[T]], wrapper)

@my_decorator
def my_func(i: int, s: str) -> bool: ...

reveal_type(my_func)  # Revealed type is 'def (*Any, **Any) -> file.Container[builtins.bool*]'

Which mypy sorcery is required to keep the argument types intact for my_func?

Using typing.Protocol looks promising, but I don't see how to make it work.


Solution

  • A quick update on the accepted answer, reflecting the final state of PEP 612. The current form is:

    def my_decorator(f: Callable[P, R]) -> Callable[P, Container[R]]:
    
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> Container[R]:
            return Container(f(*args, **kwargs))
    
        return wrapper
    

    Annotating both *args with P.args and **kwargs with P.kwargs in the wrapper function is mandatory: PEP 612 requires the two to be used together. The complete code example would be:

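    # ParamSpec is available in typing from Python 3.10; on earlier versions
    # it can be imported from typing_extensions instead.
    from typing import (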
        Callable,
        Generic,
        ParamSpec,
        TypeVar,
    )
    
    T = TypeVar("T")
    
    class Container(Generic[T]):
        def __init__(self, item: T):
            self.item: T = item
    
    P = ParamSpec("P")
    R = TypeVar("R")
    
    
    def my_decorator(f: Callable[P, R]) -> Callable[P, Container[R]]:
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> Container[R]:
            return Container(f(*args, **kwargs))
        return wrapper
    
    @my_decorator
    def my_func(i: int, s: str) -> bool:
        ...
    

    This gives the expected type for my_func:

    (function) def my_func(
        i: int,
        s: str
    ) -> Container[bool]