MRE: https://mypy-play.net/?mypy=latest&python=3.10&gist=263dfa4914d0317291638957e51e8700&flags=strict
from typing import Callable, TypeVar, Type, Protocol
from typing_extensions import Self

_CT = TypeVar("_CT")
_T = TypeVar("_T")


class LtConverter(Protocol):
    def __new__(cls: Type[_T], x: _CT) -> _T: ...
    def __lt__(self, other: Self) -> bool: ...


def lt_through(conv: Type[LtConverter]) -> Callable[[_CT, _CT], bool]:
    def comparator(op: Callable[[LtConverter, LtConverter], bool]) -> Callable[[_CT, _CT], bool]:
        def method(self: _CT, other: _CT) -> bool:
            return op(conv(self), conv(other))
        return method
    return comparator(lambda x, y: x < y)
In the MRE, lt_through(conv) returns a function that can be used as a method of a class, so that the class's "less than" operator compares instances after they have been cast to type conv. Here conv is any class that follows the protocol: it can be constructed from (i.e. cast from) a value, and it supports the less-than operation with itself.
mypy gives the following error message:
main.py:18: error: Argument 1 to "LtConverter" has incompatible type "_CT"; expected "_CT" [arg-type]
I am vaguely confused by this message. Why is _CT considered to be different from _CT? Why does this error occur for conv(self) but not for conv(other) (which can be seen by inspecting the column number in the verbose output)?
Your code is almost correct; mypy's error reporting just doesn't make the remaining issues easy to spot and fix. I'll answer your two questions first:
I am vaguely confused by this message. Why is _CT considered to be different from _CT?
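The _CT in def __new__(cls: Type[_T], x: _CT) -> _T: ... is not actually the same _CT as the one in def method(self: _CT, other: _CT) -> bool:. You've just reused the name _CT in two different type variable scopes: a type variable belongs to the function (or generic class) whose signature binds it, so the _CT in __new__ is scoped to __new__, while the _CT in method comes from the enclosing lt_through/comparator signatures. Renaming them apart, this is what mypy actually sees you doing: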
_CT1 = TypeVar("_CT1")
_CT2 = TypeVar("_CT2")

...

class LtConverter(Protocol):
    def __new__(cls: Type[_T], x: _CT1) -> _T:
        ...

...
        def method(self: _CT2, other: _CT2) -> bool:
            return op(conv(self), conv(other))  # mypy: Argument 1 to "LtConverter" has incompatible type "_CT2"; expected "_CT1" [arg-type]
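The same scoping rule can be seen outside the MRE in a small, hypothetical example (first, pair and Box are illustrative names, not from the question). The T in first and the T in pair share a name but are independent variables, each solved per call; a T bound by a generic class, by contrast, is shared by every method that mentions it, which is what parameterising LtConverter relies on.

```python
from typing import Callable, Generic, List, TypeVar

T = TypeVar("T")
S = TypeVar("S")


def first(items: List[T]) -> T:
    # This T is scoped to `first` and is solved afresh at every call.
    return items[0]


def pair(item: T) -> List[T]:
    # Same name, but a separate type variable scoped to `pair`:
    # nothing links it to the T in `first`.
    return [item, item]


class Box(Generic[T]):
    # Here T is bound by the class, so every method that mentions T
    # refers to the same type parameter.
    def __init__(self, item: T) -> None:
        self.item = item

    def replace(self, item: T) -> "Box[T]":
        return Box(item)

    def map(self, f: Callable[[T], S]) -> "Box[S]":
        # S appears only in this method's signature, so it is scoped to `map`.
        return Box(f(self.item))
```

For example, mypy infers first([1, 2]) as int and pair("a") as List[str], even though both signatures are written with T.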
Why does this error occur for conv(self) but not for conv(other) (which can be seen by inspecting the char number in the verbose output)?
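Unfortunately, this is because mypy de-duplicates its errors: a message identical to one already reported on the same line is dropped, even if it points at a different column. Both calls produce exactly the same message here, so only the first survives. You'll see the second one if you spread the call over several lines, like the following: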
def method(self: _CT, other: _CT) -> bool:
    return op(
        conv(self),  # mypy: Argument 1 to "LtConverter" has incompatible type "_CT"; expected "_CT" [arg-type]
        conv(other),  # mypy: Argument 1 to "LtConverter" has incompatible type "_CT"; expected "_CT" [arg-type]
    )
If I understand your intention correctly, to fix this you just have to change LtConverter.__new__ to one of the following (both need import typing for typing.Any):
def __new__(cls: Type[_T], x: typing.Any) -> _T: ...
def __new__(cls, x: typing.Any) -> Self: ...
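For a concrete picture of the intended use, here is a runtime sketch with the first __new__ fix applied. The Temperature class and the choice of float as the converter are hypothetical, and Self is imported only under TYPE_CHECKING (with a string annotation) so the block runs without typing_extensions installed. The assertions exercise runtime behaviour only, not mypy's verdict on passing float as a Type[LtConverter].

```python
from typing import TYPE_CHECKING, Any, Callable, Protocol, Type, TypeVar

if TYPE_CHECKING:
    # Only the type checker needs Self; the annotation below is a string,
    # so typing_extensions is never imported at runtime.
    from typing_extensions import Self

_CT = TypeVar("_CT")
_T = TypeVar("_T")


class LtConverter(Protocol):
    # Fixed signature: accept any object instead of a separately scoped _CT.
    def __new__(cls: Type[_T], x: Any) -> _T: ...
    def __lt__(self, other: "Self") -> bool: ...


def lt_through(conv: Type[LtConverter]) -> Callable[[_CT, _CT], bool]:
    def comparator(op: Callable[[LtConverter, LtConverter], bool]) -> Callable[[_CT, _CT], bool]:
        def method(self: _CT, other: _CT) -> bool:
            # Cast both operands through conv, then compare the results.
            return op(conv(self), conv(other))
        return method
    return comparator(lambda x, y: x < y)


class Temperature:
    """Hypothetical class whose ordering is defined by converting to float."""

    def __init__(self, celsius: float) -> None:
        self.celsius = float(celsius)

    def __float__(self) -> float:
        return self.celsius

    # t1 < t2 evaluates float(t1) < float(t2), i.e. compares celsius values.
    __lt__ = lt_through(float)
```

Since sorted() only needs __lt__, sorted([Temperature(3), Temperature(-1)]) then orders the instances by their celsius value.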
This is based on the assumptions that def method(self: _CT, other: _CT): means "I don't care what self and other are, as long as they're up-castable to the same type", and that LtConverter can be instantiated with any object to provide a valid __lt__ comparison. In practice, your intended implementation might be a bit stricter than this, in which case you should consider putting a bound on the _CT type variable, parameterising LtConverter with _CT, and/or replacing x: typing.Any in LtConverter.__new__ with x: _CT. Here's one possible version with stricter typing:
from typing import Callable, TypeVar, Type, Protocol
from typing_extensions import Self

_CT = TypeVar("_CT")
_T = TypeVar("_T")
_ConverteeT = TypeVar("_ConverteeT", covariant=True)


class LtConverter(Protocol[_ConverteeT]):
    def __new__(cls: Type[_T], x: _ConverteeT) -> _T: ...
    def __lt__(self, other: Self) -> bool: ...


def lt_through(conv: Type[LtConverter[_CT]]) -> Callable[[_CT, _CT], bool]:
    def comparator(op: Callable[[LtConverter[_CT], LtConverter[_CT]], bool]) -> Callable[[_CT, _CT], bool]:
        def method(self: _CT, other: _CT) -> bool:
            return op(conv(self), conv(other))
        return method
    return comparator(lambda x, y: x < y)
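The type-variable-bound option mentioned above can also be sketched on its own. Assuming, as a hypothetical simplification, that the only converter you need is float, bounding the type variable by typing.SupportsFloat lets both float(...) calls be checked directly, without a protocol. _Num, lt_through_float and Temperature are illustrative names, and the assertions check runtime behaviour only.

```python
from typing import SupportsFloat, TypeVar

# Hypothetical type variable: only types with __float__ satisfy the bound.
_Num = TypeVar("_Num", bound=SupportsFloat)


def lt_through_float(self: _Num, other: _Num) -> bool:
    # float() accepts any SupportsFloat, so the bound makes both casts valid.
    return float(self) < float(other)


class Temperature:
    """Hypothetical class ordered by its float() value."""

    def __init__(self, celsius: float) -> None:
        self.celsius = float(celsius)

    def __float__(self) -> float:
        return self.celsius

    # Reuse the free function as the class's "less than" operator.
    __lt__ = lt_through_float
```

The same function also works on plain numbers, e.g. lt_through_float(1, 2.5) is True. The trade-off is that the converter is fixed to float, whereas the parameterised LtConverter keeps it pluggable.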