
How to distinguish between Base Class and Derived Class in generics using Typing and Mypy


Consider the following code:

from typing import TypeVar
import dataclasses


@dataclasses.dataclass
class A:
    pass


@dataclasses.dataclass
class B(A):
    pass


T = TypeVar("T", A, B)


def fun(
    x1: T,
    x2: T,
) -> int:
    if type(x1) != type(x2):
        raise TypeError("must be same type!")

    if type(x1) == A:
        return 5

    elif type(x1) == B:
        return 10
    else:
        raise TypeError("Type not handled")


fun(x1=A(), x2=A())  # OK
fun(x1=B(), x2=B())  # OK
fun(x1=B(), x2=A())  # Will throw TypeError, how can I get mypy to say this is an error?
fun(x1=A(), x2=B())  # Will throw TypeError, how can I get mypy to say this is an error?

Mypy is not seeing any problem here. It seems like it is always interpreting the passed object as a base class object of type A.

Is there a way to make the generic even more strict in the sense that it is sensitive to the exact class type? Such that if x1 is of type B, then also x2 must be exactly of type B? If x1 is of type A then also x2 must be exactly of type A?


Solution

  • This was a fun question - at first I considered solving it in the following way:

    from typing import overload
    import dataclasses
    
    
    @dataclasses.dataclass
    class A:
        pass
    
    
    @dataclasses.dataclass
    class B(A):
        pass
    
    
    @overload
    def fun(x1: B, x2: B) -> int:
        ...
    
    
    @overload
    def fun(x1: A, x2: A) -> int:
        ...
    
    
    def fun(
        x1: A | B,
        x2: A | B,
    ) -> int:
        if type(x1) != type(x2):
            raise TypeError("must be same type!")
    
        if type(x1) == A:
            return 5
    
        elif type(x1) == B:
            return 10
        else:
            raise TypeError("Type not handled")
    
    
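    fun(x1=A(), x2=A())  # OK
    fun(x1=B(), x2=B())  # OK
    fun(x1=B(), x2=A())  # Still accepted by mypy (matches the A, A overload); raises TypeError at runtime
    fun(x1=A(), x2=B())  # Still accepted by mypy (matches the A, A overload); raises TypeError at runtime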
    

    At first I thought this might be a quirk of the way TypeVar works, but even with explicit overloads for (A, A) and (B, B), mypy still raises no error on the final two calls. Both of them simply match the (A, A) overload, because a B is acceptable wherever an A is expected. Python's static type system does not distinguish between direct instances of a class and instances of its subclasses. You could enforce a structural type with a Protocol, so long as there is a structural difference between A and B.

    Even if you made the function argument a list[A], instances of B would still be valid elements of that list, for the same reason.
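
As a small illustration of the substitutability rule described above, here is a minimal sketch. The helper `describe` is hypothetical and not part of the question's code. It shows that a checker only asks whether a value is assignable to `A`, which every `B` is, while the exact class is only visible at runtime.

```python
import dataclasses


@dataclasses.dataclass
class A:
    pass


@dataclasses.dataclass
class B(A):
    pass


def describe(x: A) -> str:
    # Hypothetical helper: annotated with A, so a checker accepts subclasses too.
    return type(x).__name__


# Accepted by a static checker: each B is also an A.
items: list[A] = [A(), B()]

# Static view: every element counts as "an A".
all_are_a = all(isinstance(item, A) for item in items)

# Runtime view: the exact classes differ, which is what the
# type(x1) != type(x2) check in fun() detects.
names = [describe(item) for item in items]
exact_types_differ = type(items[0]) is not type(items[1])
```

This is why the mismatch can only be caught at runtime as long as B remains a subclass of A.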

    If you're trying to get B to pick up all the attributes of A while keeping the two as distinct types, I would instead do it this way, with A now a hidden base class and A2 the class exposed to end users:

    from typing import TypeVar
    import dataclasses
    
    
    @dataclasses.dataclass
    class A:
        pass
    
    
    @dataclasses.dataclass
    class B(A):
        pass
    
    @dataclasses.dataclass
    class A2(A):
        # Note, this class would be empty in practice as well
        pass
    
    
    
    T = TypeVar("T", A2, B)
    
    def fun(
        x1: T,
        x2: T,
    ) -> int:
        if type(x1) != type(x2):
            raise TypeError("must be same type!")
    
        if type(x1) == A2:
            return 5
    
        elif type(x1) == B:
            return 10
        else:
            raise TypeError("Type not handled")
    
    
    fun(x1=A2(), x2=A2())  # OK
    fun(x1=B(), x2=B())  # OK
    fun(x1=B(), x2=A2())  # mypy now reports an error: T cannot be both B and A2 (also a TypeError at runtime)
    fun(x1=A2(), x2=B())  # mypy now reports an error: T cannot be both A2 and B (also a TypeError at runtime)
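
Why the sibling layout helps can be checked at runtime with a minimal sketch. It uses the same three class names as above, plus an illustrative field `x` that is an assumption and not part of the original classes.

```python
import dataclasses


@dataclasses.dataclass
class A:
    x: int = 0  # illustrative field, not in the original classes


@dataclasses.dataclass
class A2(A):
    pass


@dataclasses.dataclass
class B(A):
    pass


# Both public classes inherit the fields of the hidden base A...
a2_fields = [f.name for f in dataclasses.fields(A2)]
b_fields = [f.name for f in dataclasses.fields(B)]

# ...but neither is a subclass of the other, so no single constraint of
# TypeVar("T", A2, B) can accept a mixed pair of arguments.
siblings_unrelated = not issubclass(B, A2) and not issubclass(A2, B)
share_base = issubclass(A2, A) and issubclass(B, A)
```

The trade-off is that callers must construct A2 rather than A, since a plain A matches neither constraint.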
    

    Hope this is useful!