I have a decorator that takes a function as an argument and returns a function with the same signature. The first argument to this function must have a foo attribute, and the decorator performs some side effect based on its value.
I'm trying to type-hint all of this with a mix of Protocol and ParamSpec (I'm using Python 3.10), but apparently I'm doing it wrong. Below is a toy implementation that fails.
import functools
from typing import Callable, Concatenate, ParamSpec, Protocol, TypeVar


class HasFoo(Protocol):
    foo: int


P = ParamSpec("P")
T = TypeVar("T")


def print_foo(
    f: Callable[Concatenate[HasFoo, P], T]
) -> Callable[Concatenate[HasFoo, P], T]:
    """This is my decorator.

    It prints `x.foo` where `x` is the first argument of the
    decorated function.
    """

    @functools.wraps(f)
    def wrapped(has_foo: HasFoo, *args: P.args, **kwargs: P.kwargs) -> T:
        print(has_foo.foo)
        return f(has_foo, *args, **kwargs)

    return wrapped


class ReallyHasFoo:
    """Some actual implementation of `HasFoo`."""

    def __init__(self) -> None:
        self.foo = 0


@print_foo
def f(has_foo: ReallyHasFoo) -> None:
    pass
With this example, mypy complains as follows:
test.py:37: error: Argument 1 to "print_foo" has incompatible type "Callable[[ReallyHasFoo], None]"; expected "Callable[[HasFoo], None]"
test.py:37: note: This may be because "f" has arguments named: "has_foo"
Any idea how to sort this out?
Thanks to @SUTerliakov's answer, which explains the problem but does not provide a solution, I've been able to find one.
The issue is here:
def print_foo(
    f: Callable[Concatenate[HasFoo, P], T]
) -> Callable[Concatenate[HasFoo, P], T]:
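This signature assumes that f can accept any argument that has a foo attribute, which is not true here. Callable parameter types are contravariant: a function that only accepts ReallyHasFoo is not a Callable[[HasFoo], None], because the decorated function would then claim to accept any HasFoo. As pointed out by @SUTerliakov, the body of f could be: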
assert isinstance(has_foo, ReallyHasFoo)
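To make the problem concrete, here is a small runtime sketch. `OtherHasFoo` and `call_with` are hypothetical names introduced only for illustration: `call_with` stands in for any code that trusts the signature `Callable[[HasFoo], None]`, which is what the original `print_foo` would give the decorated `f`.

```python
from typing import Callable, Protocol


class HasFoo(Protocol):
    foo: int


class ReallyHasFoo:
    def __init__(self) -> None:
        self.foo = 0


class OtherHasFoo:
    """Hypothetical second implementation of `HasFoo`."""

    def __init__(self) -> None:
        self.foo = 1


def f(has_foo: ReallyHasFoo) -> None:
    # The body may legitimately rely on the concrete type.
    assert isinstance(has_foo, ReallyHasFoo)


def call_with(g: Callable[[HasFoo], None], x: HasFoo) -> None:
    # Hypothetical caller: its signature promises that g accepts *any* HasFoo.
    g(x)


call_with(f, ReallyHasFoo())  # mypy rejects passing f here; at runtime this call works
try:
    call_with(f, OtherHasFoo())  # a valid HasFoo per the signature, but f rejects it
except AssertionError:
    print("f cannot handle OtherHasFoo")
```

This is exactly the situation mypy guards against: both objects satisfy `HasFoo`, but `f` only handles one of them.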
The solution is to use a TypeVar bound to HasFoo instead of HasFoo itself, so that the decorator preserves the concrete type of the first argument, as shown below:
# Reuses functools, the typing imports, HasFoo, P and T from the question.
HasFooSubtype = TypeVar("HasFooSubtype", bound=HasFoo)


def print_foo(
    f: Callable[Concatenate[HasFooSubtype, P], T]
) -> Callable[Concatenate[HasFooSubtype, P], T]:
    @functools.wraps(f)
    def wrapped(
        has_foo: HasFooSubtype, *args: P.args, **kwargs: P.kwargs
    ) -> T:
        print(has_foo.foo)
        return f(has_foo, *args, **kwargs)

    return wrapped
A bit clumsy, but it works, and mypy no longer complains.