Consider the following code:
```python
from typing import TypeVar
import dataclasses


@dataclasses.dataclass
class A:
    pass


@dataclasses.dataclass
class B(A):
    pass


T = TypeVar("T", A, B)


def fun(
    x1: T,
    x2: T,
) -> int:
    if type(x1) != type(x2):
        raise TypeError("must be same type!")
    if type(x1) == A:
        return 5
    elif type(x1) == B:
        return 10
    else:
        raise TypeError("Type not handled")


fun(x1=A(), x2=A())  # OK
fun(x1=B(), x2=B())  # OK
fun(x1=B(), x2=A())  # Will throw TypeError, how can I get mypy to say this is an error?
fun(x1=A(), x2=B())  # Will throw TypeError, how can I get mypy to say this is an error?
```
Mypy does not see any problem here. It seems to always interpret the passed object as an instance of the base class `A`.

Is there a way to make the generic even stricter, so that it is sensitive to the exact class? That is, if `x1` is of type `B`, then `x2` must also be exactly of type `B`, and if `x1` is of type `A`, then `x2` must also be exactly of type `A`?
This was a fun question. At first, I considered solving it like this:
```python
from typing import overload
import dataclasses


@dataclasses.dataclass
class A:
    pass


@dataclasses.dataclass
class B(A):
    pass


@overload
def fun(x1: B, x2: B) -> int:
    ...


@overload
def fun(x1: A, x2: A) -> int:
    ...


def fun(
    x1: A | B,
    x2: A | B,
) -> int:
    if type(x1) != type(x2):
        raise TypeError("must be same type!")
    if type(x1) == A:
        return 5
    elif type(x1) == B:
        return 10
    else:
        raise TypeError("Type not handled")


fun(x1=A(), x2=A())
fun(x1=B(), x2=B())
fun(x1=B(), x2=A())  # runtime TypeError, but mypy is silent: matches the (A, A) overload
fun(x1=A(), x2=B())  # runtime TypeError, but mypy is silent: matches the (A, A) overload
```
Initially I thought this might be a quirk of how `TypeVar` works, but even if we specify with overloads that the arguments must be either `A, A` or `B, B`, mypy still does not report an error on the final two lines. Both of those calls simply match the `A, A` overload, because a `(B, A)` or `(A, B)` pair is still compatible with `(A, A)`: every `B` is also an `A`. Python's type system does not distinguish between direct instances of a class and instances of its subclasses; a subclass instance is accepted wherever the base class is expected. You could enforce a structural type with a `Protocol`, so long as there were a structural difference between `A` and `B`.
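As a minimal sketch of the `Protocol` idea, assume (hypothetically) that `B` gains a field `extra` that `A` does not have. A protocol requiring `extra` then accepts `B` but not `A`, both for mypy and, thanks to `@runtime_checkable`, for an `isinstance` check. The names `extra`, `HasExtra` and `fun_b_only` are illustrative only.

```python
import dataclasses
from typing import Protocol, runtime_checkable


@dataclasses.dataclass
class A:
    pass


@dataclasses.dataclass
class B(A):
    # Hypothetical extra field: the structural difference between A and B
    extra: int = 0


@runtime_checkable
class HasExtra(Protocol):
    # Hypothetical protocol: any object with an int attribute "extra"
    extra: int


def fun_b_only(x1: HasExtra, x2: HasExtra) -> int:
    return x1.extra + x2.extra


fun_b_only(B(), B())  # OK for mypy and at runtime
# fun_b_only(A(), B())  # mypy would reject this: A has no attribute "extra"
```

Note that this only works in one direction: because `B` inherits everything `A` has, any protocol that `A` satisfies is also satisfied by `B`, so a protocol cannot reject a `B` where an `A`-shaped object is expected.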
Even if you made the function argument a `list[A]`, instances of `B` would still be valid elements of that list, for the same reason.
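A quick runtime illustration of the subtype point (the helper `exact_type_names` is hypothetical): a `list[A]` annotation happily holds a `B`, and `isinstance` treats a `B` as an `A`, whereas comparing `type(...)` exactly, as `fun` does, tells them apart.

```python
import dataclasses


@dataclasses.dataclass
class A:
    pass


@dataclasses.dataclass
class B(A):
    pass


def exact_type_names(items: list[A]) -> list[str]:
    # The annotation accepts any A, including subclass instances such as B
    return [type(item).__name__ for item in items]


mixed: list[A] = [A(), B()]  # type-checks: each B is also an A
print(exact_type_names(mixed))  # ['A', 'B']
print(isinstance(B(), A), type(B()) is A)  # True False
```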
If you want `B` to pick up all the attributes of `A` while still being a distinct type, I would instead do it this way, with `A` now a hidden base class and `A2` exposed to the end user:
```python
from typing import TypeVar
import dataclasses


@dataclasses.dataclass
class A:
    pass


@dataclasses.dataclass
class B(A):
    pass


@dataclasses.dataclass
class A2(A):
    # Note, this class would be empty in practice as well
    pass


T = TypeVar("T", A2, B)


def fun(
    x1: T,
    x2: T,
) -> int:
    if type(x1) != type(x2):
        raise TypeError("must be same type!")
    if type(x1) == A2:
        return 5
    elif type(x1) == B:
        return 10
    else:
        raise TypeError("Type not handled")


fun(x1=A2(), x2=A2())  # OK
fun(x1=B(), x2=B())  # OK
fun(x1=B(), x2=A2())  # mypy now reports an error (and this still raises TypeError at runtime)
fun(x1=A2(), x2=B())  # mypy now reports an error (and this still raises TypeError at runtime)
```
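The constrained `TypeVar` now catches the mixed calls because `A2` and `B` are siblings: they share the base `A`, but neither is a subclass of the other, so no single constraint fits a mixed pair. A small runtime check of those class relationships, assuming the same three dataclasses as above:

```python
import dataclasses


@dataclasses.dataclass
class A:
    pass


@dataclasses.dataclass
class B(A):
    pass


@dataclasses.dataclass
class A2(A):
    pass


# Neither sibling is a subclass of the other, so neither constraint
# of TypeVar("T", A2, B) can accept a (B, A2) or (A2, B) pair.
print(issubclass(B, A2), issubclass(A2, B))  # False False
# Both still inherit everything from the hidden base class A.
print(issubclass(A2, A), issubclass(B, A))  # True True
```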
Hope this is useful!