
How to test whether an __annotations__ member has the same type as the generic class parameter?


I want to select only the annotated members whose type is the same as the type parameter of the generic in this base class:

from typing import Generic, TypeVar


T = TypeVar("T")

class MarkFinder(Generic[T]):
    def __init_subclass__(cls, **kwargs):
        cls.marked = tuple(
            name for name, annotated_type in cls.__annotations__.items() 
            if some_condition(annotated_type)
        )

So that if I inherit:

T2 = TypeVar("T2")

class Inheritor(MarkFinder[T2]):
    a: T2
    b: int
    c: T2

Then Inheritor.marked should just be ('a', 'c').

I have tried to replace some_condition(annotated_type) with:

cls.__parameters__[0] is annotated_type or cls.__parameters__[0] == annotated_type

but even though they have the same names, the types are not equal.

What is the correct condition? Or is this impossible?
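One likely reason the attempt above fails: this MarkFinder.__init_subclass__ never calls super().__init_subclass__(), so Generic never gets to set Inheritor.__parameters__. The lookup cls.__parameters__ then silently resolves to MarkFinder's inherited (T,), and because TypeVar objects compare by identity (they define no custom equality), T never matches T2. The minimal sketch below illustrates the difference; the class names NoSuper, WithSuper, A and B are hypothetical, made up for this illustration.

```python
from typing import Generic, TypeVar

T = TypeVar("T")
T2 = TypeVar("T2")


class NoSuper(Generic[T]):
    def __init_subclass__(cls, **kwargs):
        # Like the question's version: super().__init_subclass__() is never called,
        # so Generic.__init_subclass__ does not run for subclasses.
        cls.params_seen = cls.__parameters__


class WithSuper(Generic[T]):
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)  # Generic sets cls.__parameters__ here
        cls.params_seen = cls.__parameters__


class A(NoSuper[T2]):
    x: T2


class B(WithSuper[T2]):
    x: T2


print(A.params_seen)  # (~T,)  inherited from NoSuper, not T2
print(B.params_seen)  # (~T2,)
print(B.params_seen[0] is B.__annotations__["x"])  # True
```

Even with the super() call, indexing __parameters__[0] only covers the simplest case, which is why the approach below reads the type argument from __orig_bases__ instead.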


Solution

    This is possible, but with a few caveats. The crux of the approach, as well as some things to keep in mind, is explained in the following post, so I suggest you read it first:

    Access type argument in any specific subclass of user-defined Generic[T] class

    The TL;DR is to grab the type argument from the original base class (which will be a generic alias type) in __orig_bases__ and compare the annotations against it by identity.
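To make the TL;DR concrete, here is a minimal sketch of what __orig_bases__, get_origin and get_args return for the Inheritor class from the question. For this illustration MarkFinder is stripped down to a bare Generic[T] subclass with no __init_subclass__.

```python
from typing import Generic, TypeVar, get_args, get_origin

T = TypeVar("T")
T2 = TypeVar("T2")


class MarkFinder(Generic[T]):
    pass  # stripped down: no __init_subclass__ in this illustration


class Inheritor(MarkFinder[T2]):
    a: T2
    b: int
    c: T2


# __orig_bases__ keeps the subscripted base exactly as written in the class statement
base = Inheritor.__orig_bases__[0]
print(get_origin(base) is MarkFinder)  # True
type_arg = get_args(base)[0]  # the very same T2 object used in the annotations
print([name for name, ann in Inheritor.__annotations__.items() if ann is type_arg])
# ['a', 'c']
```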

    From the way you phrased your question, I assume you only want this to apply to type variables and not to specific type arguments. Here is how you could do it:

    from typing import Any, ClassVar, Generic, TypeVar, get_args, get_origin
    
    
    T = TypeVar("T")
    
    
    class MarkFinder(Generic[T]):
        marked: ClassVar[tuple[str, ...]] = ()
    
        @classmethod
        def __init_subclass__(cls, **kwargs: Any) -> None:
            # Calling super() lets Generic set cls.__parameters__ for this subclass
            super().__init_subclass__(**kwargs)
            for base in cls.__orig_bases__:  # type: ignore[attr-defined]
                origin = get_origin(base)
                # Skip bases that are not (parameterized) MarkFinder subclasses
                if origin is None or not issubclass(origin, MarkFinder):
                    continue
                type_arg = get_args(base)[0]
                if not isinstance(type_arg, TypeVar):
                    return  # do not touch non-generic subclasses
                # TypeVars compare by identity, so `is` finds annotations using this exact TypeVar
                cls.marked += tuple(
                    name for name, annotation in cls.__annotations__.items()
                    if annotation is type_arg
                )
    

    Usage demo:

    T2 = TypeVar("T2")
    
    
    class Child(MarkFinder[T2]):
        a: T2
        b: int
        c: T2
    
    
    T3 = TypeVar("T3")
    
    
    class GrandChild(Child[T3]):
        d: T3
        e: str
    
    
    class SpecificDescendant1(GrandChild[int]):
        f: int
    
    
    class SpecificDescendant2(GrandChild[str]):
        f: float
    
    
    print(Child.marked)                # ('a', 'c')
    print(GrandChild.marked)           # ('a', 'c', 'd')
    print(SpecificDescendant1.marked)  # ('a', 'c', 'd')
    print(SpecificDescendant2.marked)  # ('a', 'c', 'd')
    

    The check that type_arg is a TypeVar instance is important. Without it, SpecificDescendant1 would also have 'f' in its marked tuple, because its annotation int is the same object as its type argument int. (If that is what you want, simply remove the check from the base class's __init_subclass__ method.)
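For comparison, here is a minimal sketch of the variant with the TypeVar check removed, so that concrete type arguments are matched too. The class names MarkAll, Base and Specific are hypothetical. As a small deviation from the code above (an assumption, not part of the original answer), it reads __orig_bases__ from cls.__dict__ so that a subclass written without a subscripted base does not reuse its parent's __orig_bases__.

```python
from typing import Any, ClassVar, Generic, TypeVar, get_args, get_origin

T = TypeVar("T")


class MarkAll(Generic[T]):
    marked: ClassVar[tuple[str, ...]] = ()

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # Only bases written in this class statement; inherited __orig_bases__ is ignored
        for base in cls.__dict__.get("__orig_bases__", ()):
            origin = get_origin(base)
            if origin is None or not issubclass(origin, MarkAll):
                continue
            type_arg = get_args(base)[0]
            # No TypeVar check: a concrete argument such as int is matched as well
            cls.marked += tuple(
                name for name, annotation in cls.__annotations__.items()
                if annotation is type_arg
            )


T2 = TypeVar("T2")


class Base(MarkAll[T2]):
    a: T2
    b: int


class Specific(Base[int]):
    f: int
    g: str


print(Base.marked)      # ('a',)
print(Specific.marked)  # ('a', 'f')
```

Note that only each class's own annotations are compared, so Base.b is not retroactively marked in Specific even though T2 is bound to int there.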