Access type argument in any specific subclass of user-defined Generic[T] class


Context

Say we want to define a custom generic (base) class that inherits from typing.Generic.

For the sake of simplicity, we want it to be parameterized by a single type variable T. So the class definition starts like this:

from typing import Generic, TypeVar

T = TypeVar("T")

class GenericBase(Generic[T]):
    ...

Question

Is there a way to access the type argument T in any specific subclass of GenericBase?

The solution should be universal enough to work in a subclass with additional bases besides GenericBase and be independent of instantiation (i.e. work on the class level).

The desired outcome is a class-method like this:

from typing import Type

class GenericBase(Generic[T]):

    @classmethod
    def get_type_arg(cls) -> Type[T]:
        ...

Usage

class Foo:
    pass

class Bar:
    pass

class Specific(Foo, GenericBase[str], Bar):
    pass

print(Specific.get_type_arg())

The output should be <class 'str'>.

Bonus

It would be nice if all relevant type annotations were made such that static type checkers could correctly infer the specific class returned by get_type_arg.

Solution

  • TL;DR

    Find the specialized GenericBase in the subclass's __orig_bases__ tuple, pass it to typing.get_args, take the first element of the tuple it returns, and make sure that element is a concrete type rather than a type variable.

    1) Starting with get_args

    As pointed out in this post, the typing module for Python 3.8+ provides the get_args function. It is convenient because given a specialization of a generic type, get_args returns its type arguments (as a tuple).

    Demonstration:

    from typing import Generic, TypeVar, get_args
    
    T = TypeVar("T")
    
    class GenericBase(Generic[T]):
        pass
    
    print(get_args(GenericBase[int]))
    

    Output:

    (<class 'int'>,)
    

    This means that once we have access to a specialized GenericBase type, we can easily extract its type argument.

    2) Continuing with __orig_bases__

    As further pointed out in the aforementioned post, there is a handy little class attribute __orig_bases__ that is stored in the class namespace when a new class is created and at least one of its listed bases is not a plain class (for example GenericBase[int]). It is mentioned here in PEP 560, but is otherwise hardly documented.

    This attribute contains (as the name suggests) the original bases, as a tuple, exactly as they were written in the class statement. This distinguishes it from __bases__, which contains the already resolved bases as returned by types.resolve_bases.

    Demonstration:

    from typing import Generic, TypeVar
    
    T = TypeVar("T")
    
    class GenericBase(Generic[T]):
        pass
    
    class Specific(GenericBase[int]):
        pass
    
    print(Specific.__bases__)
    print(Specific.__orig_bases__)
    

    Output:

    (<class '__main__.GenericBase'>,)
    (__main__.GenericBase[int],)
    

    We are interested in the original base because that is the specialization of our generic class, meaning it is the one that "knows" about the type argument (int in this example), whereas the resolved base class is just an instance of type.
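
One detail worth knowing before relying on this attribute: a class whose listed bases are all plain classes does not get its own __orig_bases__, so accessing the attribute on it falls back to ordinary attribute inheritance. The sketch below illustrates this; the Child class is hypothetical and not part of the original example.

```python
from typing import Generic, TypeVar

T = TypeVar("T")


class GenericBase(Generic[T]):
    pass


class Specific(GenericBase[int]):
    pass


# Hypothetical class for illustration: it lists only a plain class as its base.
class Child(Specific):
    pass


# `Specific` stores the attribute in its own namespace; `Child` does not.
own_in_specific = "__orig_bases__" in vars(Specific)
own_in_child = "__orig_bases__" in vars(Child)
print(own_in_specific, own_in_child)

# Attribute lookup on `Child` therefore finds the tuple stored on `Specific`.
print(Child.__orig_bases__ is Specific.__orig_bases__)
```

This should print `True False` followed by `True`. In practice it means a lookup on a further subclass still finds the specialization recorded by its parent.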

    3) Simplistic solution

    If we put these two together, we can quickly construct a simplistic solution like this:

    from typing import Generic, TypeVar, get_args
    
    T = TypeVar("T")
    
    class GenericBase(Generic[T]):
        @classmethod
        def get_type_arg_simple(cls):
            return get_args(cls.__orig_bases__[0])[0]
    
    class Specific(GenericBase[int]):
        pass
    
    print(Specific.get_type_arg_simple())
    

    Output:

    <class 'int'>
    

    But this will break as soon as we list another base class ahead of GenericBase:

    from typing import Generic, TypeVar, get_args
    
    T = TypeVar("T")
    
    class GenericBase(Generic[T]):
        @classmethod
        def get_type_arg_simple(cls):
            return get_args(cls.__orig_bases__[0])[0]
    
    class Mixin:
        pass
    
    class Specific(Mixin, GenericBase[int]):
        pass
    
    print(Specific.get_type_arg_simple())
    

    Output:

    Traceback (most recent call last):
      ...
        return get_args(cls.__orig_bases__[0])[0]
    IndexError: tuple index out of range
    

    This happens because cls.__orig_bases__[0] now happens to be Mixin, which is not a parameterized type, so get_args returns an empty tuple ().
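
The failure can be reproduced in isolation. This small sketch only compares what get_args returns for a plain class with what it returns for a specialized generic:

```python
from typing import Generic, TypeVar, get_args

T = TypeVar("T")


class GenericBase(Generic[T]):
    pass


class Mixin:
    pass


# A plain class carries no type arguments, so the result is an empty tuple.
plain_args = get_args(Mixin)
# A specialized generic reports its arguments.
alias_args = get_args(GenericBase[int])

print(plain_args)
print(alias_args)
```

Indexing the first result with [0] is what raises the IndexError shown above.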

    So what we need is a way to unambiguously identify the GenericBase from the __orig_bases__ tuple.

    4) Identifying with get_origin

    Just like typing.get_args gives us the type arguments for a generic type, typing.get_origin gives us the unspecified version of a generic type.

    Demonstration:

    from typing import Generic, TypeVar, get_origin
    
    T = TypeVar("T")
    
    class GenericBase(Generic[T]):
        pass
    
    print(get_origin(GenericBase[int]))
    print(get_origin(GenericBase[str]) is GenericBase)
    

    Output:

    <class '__main__.GenericBase'>
    True
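
For anything that is not a parameterized generic, get_origin returns None. That is what lets us skip unrelated bases such as a mixin. A short sketch (the Mixin class is only there for illustration):

```python
from typing import Generic, TypeVar, get_origin

T = TypeVar("T")


class GenericBase(Generic[T]):
    pass


class Mixin:
    pass


# Plain classes, including the unparameterized generic class itself,
# have no origin.
origin_of_mixin = get_origin(Mixin)
origin_of_generic = get_origin(GenericBase)
origin_of_builtin = get_origin(int)

print(origin_of_mixin, origin_of_generic, origin_of_builtin)
```

All three values are None, so a check like `origin is None` is enough to filter out bases that were not written with square brackets.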
    

    5) Putting them together

    With these components, we can now write a function get_type_arg that takes a class as an argument and, if one of its original bases is a specialized form of our GenericBase (or of a subclass of it), returns that base's type argument:

    from typing import Generic, TypeVar, get_origin, get_args
    
    T = TypeVar("T")
    
    class GenericBase(Generic[T]):
        pass
    
    class Specific(GenericBase[int]):
        pass
    
    def get_type_arg(cls):
        for base in cls.__orig_bases__:
            origin = get_origin(base)
            if origin is None or not issubclass(origin, GenericBase):
                continue
            return get_args(base)[0]
    
    print(get_type_arg(Specific))
    

    Output:

    <class 'int'>
    

    Now all that is left to do is embed this directly as a class-method of GenericBase, optimize it a little bit and fix the type annotations.

    One optimization is to run this algorithm only once for any given subclass of GenericBase, namely when it is defined, and then save the type in a class attribute. Since the type argument presumably never changes for a specific class, there is no need to compute it every time we want to access it. To accomplish this, we can hook into __init_subclass__ and do our loop there.

    We should also define a proper response for when get_type_arg is called on an unspecified (still generic) class. An AttributeError seems appropriate.

    6) Full working example

    from typing import Any, Generic, Optional, Type, TypeVar, get_args, get_origin
    
    
    # The `GenericBase` must be parameterized with exactly one type variable.
    T = TypeVar("T")
    
    
    class GenericBase(Generic[T]):
        _type_arg: Optional[Type[T]] = None  # set in specified subclasses
    
        # No decorator needed: `__init_subclass__` is implicitly a class method.
        def __init_subclass__(cls, **kwargs: Any) -> None:
            """
            Initializes a subclass of `GenericBase`.
    
            Identifies the specified `GenericBase` among all base classes and
            saves the provided type argument in the `_type_arg` class attribute.
            """
            super().__init_subclass__(**kwargs)
            for base in cls.__orig_bases__:  # type: ignore[attr-defined]
                origin = get_origin(base)
                if origin is None or not issubclass(origin, GenericBase):
                    continue
                type_arg = get_args(base)[0]
                # Do not set the attribute for GENERIC subclasses!
                if not isinstance(type_arg, TypeVar):
                    cls._type_arg = type_arg
                    return
    
        @classmethod
        def get_type_arg(cls) -> Type[T]:
            if cls._type_arg is None:
                raise AttributeError(
                    f"{cls.__name__} is generic; type argument unspecified"
                )
            return cls._type_arg
    
    
    def demo_a() -> None:
        class SpecificA(GenericBase[int]):
            pass
    
        print(SpecificA.get_type_arg())
    
    
    def demo_b() -> None:
        class Foo:
            pass
    
        class Bar:
            pass
    
        class GenericSubclass(GenericBase[T]):
            pass
    
        class SpecificB(Foo, GenericSubclass[str], Bar):
            pass
    
        type_b = SpecificB.get_type_arg()
        print(type_b)
        e = type_b.lower("E")  # static type checkers correctly infer `str` type
        assert e == "e"
    
    
    if __name__ == '__main__':
        demo_a()
        demo_b()
    

    Output:

    <class 'int'>
    <class 'str'>
    

    An IDE like PyCharm even provides the correct auto-suggestions for whatever type is returned by get_type_arg, which is really nice. 🎉
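
The error path is not exercised by the demos above, so here is a hedged sketch of it. It repeats a compact copy of GenericBase to stay self-contained; StillGeneric and the error_message helper are hypothetical names introduced only for this illustration.

```python
from typing import Any, Generic, Optional, Type, TypeVar, get_args, get_origin

T = TypeVar("T")


class GenericBase(Generic[T]):
    _type_arg: Optional[Type[T]] = None

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        for base in cls.__orig_bases__:  # type: ignore[attr-defined]
            origin = get_origin(base)
            if origin is None or not issubclass(origin, GenericBase):
                continue
            type_arg = get_args(base)[0]
            if not isinstance(type_arg, TypeVar):
                cls._type_arg = type_arg
                return

    @classmethod
    def get_type_arg(cls) -> Type[T]:
        if cls._type_arg is None:
            raise AttributeError(
                f"{cls.__name__} is generic; type argument unspecified"
            )
        return cls._type_arg


# Hypothetical subclass that keeps the type variable open.
class StillGeneric(GenericBase[T]):
    pass


def error_message(cls: type) -> str:
    """Return the AttributeError text, or an empty string if none is raised."""
    try:
        cls.get_type_arg()
    except AttributeError as exc:
        return str(exc)
    return ""


print(error_message(GenericBase))
print(error_message(StillGeneric))
```

Both calls should report that the class is generic, because neither class ever had a concrete type stored in _type_arg.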

    7) Caveats

    • The __orig_bases__ attribute is sparsely documented, so I am not sure it should be considered entirely stable, although it does not appear to be "just an implementation detail" either. I would suggest keeping an eye on that. Python 3.12 added types.get_original_bases as a documented way to read it.
    • mypy seems to agree with this caution and reports an attr-defined error on the line that accesses __orig_bases__, hence the type: ignore comment on that line.
    • The entire setup is for one single type parameter for our generic class. It can be adapted relatively easily to multiple parameters, though annotations for type checkers might become more tricky.
    • This method does not work when called directly on a specialized GenericBase, i.e. GenericBase[str].get_type_arg(). In that case, simply call typing.get_args on the specialized class, as shown at the very beginning.
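
As a rough illustration of the multi-parameter adaptation mentioned above, here is a minimal sketch with two type variables. PairBase, get_type_args, and IntToStr are hypothetical names, and the return annotation is deliberately loose (a tuple of Any), since precise typing is the tricky part. The sketch assumes the subclass specializes all parameters at once and in the original order; partially specialized or reordered intermediate generics would need extra bookkeeping.

```python
from typing import Any, Generic, Optional, Tuple, TypeVar, get_args, get_origin

K = TypeVar("K")
V = TypeVar("V")


class PairBase(Generic[K, V]):
    _type_args: Optional[Tuple[Any, ...]] = None  # set in specified subclasses

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        for base in cls.__orig_bases__:  # type: ignore[attr-defined]
            origin = get_origin(base)
            if origin is None or not issubclass(origin, PairBase):
                continue
            args = get_args(base)
            # Store the arguments only when none of them is still a TypeVar.
            if not any(isinstance(arg, TypeVar) for arg in args):
                cls._type_args = args
                return

    @classmethod
    def get_type_args(cls) -> Tuple[Any, ...]:
        if cls._type_args is None:
            raise AttributeError(
                f"{cls.__name__} is generic; type arguments unspecified"
            )
        return cls._type_args


class IntToStr(PairBase[int, str]):
    pass


print(IntToStr.get_type_args())
```

This should print `(<class 'int'>, <class 'str'>)`.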