
Get Type Argument of Arbitrarily High Generic Parent Class at Runtime


Given this:

from typing import Generic, TypeVar

T = TypeVar('T')

class Parent(Generic[T]):
    pass

I can get int from Parent[int] using typing.get_args(Parent[int])[0].
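As a quick illustration of the two introspection helpers involved, here is a minimal sketch; the variable names are only for demonstration.

```python
import typing
from typing import Generic, TypeVar

T = TypeVar("T")


class Parent(Generic[T]):
    pass


alias = Parent[int]

# `get_origin` gives back the unsubscripted class, `get_args` the type arguments.
origin = typing.get_origin(alias)
args = typing.get_args(alias)

# An unsubscripted class has no origin and no arguments.
bare_origin = typing.get_origin(Parent)
bare_args = typing.get_args(Parent)
```

Here `origin` is `Parent` and `args` is `(int,)`, while the bare class yields `None` and `()`.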

The problem becomes a bit more complicated with the following:

class Child1(Parent[int]):
    pass

class Child2(Child1):
    pass

To support an arbitrarily long inheritance hierarchy, I made the below solution:

import typing
from dataclasses import dataclass
from typing import Any, Optional

@dataclass(frozen=True)
class Found:
    value: Any

def get_parent_type_parameter(child: type) -> Optional[Found]:
    for base in child.mro():
        # If no base classes of `base` are generic, then `__orig_bases__` is nonexistent causing an `AttributeError`.
        # Instead, we want to skip iteration.
        for generic_base in getattr(base, "__orig_bases__", ()):
            if typing.get_origin(generic_base) is Parent:
                [type_argument] = typing.get_args(generic_base)

                # Return `Found(type_argument)` instead of `type_argument` to differentiate between `Parent[None]` 
                # as a base class and `Parent` not appearing as a base class.
                return Found(type_argument)

    return None

such that get_parent_type_parameter(Child2) returns Found(int), that is, int wrapped in Found. I am only interested in the type argument of one particular base class (Parent), so I've hardcoded that class into get_parent_type_parameter and ignore any other base classes.

But my above solution breaks down with chains like this:

class Child3(Parent[T], Generic[T]):
    pass

where get_parent_type_parameter(Child3[int]) returns Found(T) instead of Found(int).
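The reason seems to be that the concrete argument is stored on the subscripted alias, while `__orig_bases__` only records what the class statement saw. A minimal sketch of that split (the variable names are illustrative):

```python
import typing
from typing import Generic, TypeVar

T = TypeVar("T")


class Parent(Generic[T]):
    pass


class Child3(Parent[T], Generic[T]):
    pass


alias = Child3[int]

# The concrete argument lives on the alias, not on the class.
alias_origin = typing.get_origin(alias)
alias_args = typing.get_args(alias)

# The class statement only ever saw `Parent[T]`, so that is what gets recorded.
recorded_base = Child3.__orig_bases__[0]
recorded_args = typing.get_args(recorded_base)
```

`alias_args` is `(int,)`, but `recorded_args` is `(T,)`. Since attribute access such as `.mro()` on the alias is forwarded to `Child3`, walking the MRO never sees the `int`.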

While any answers that tackle Child3 are already great, being able to deal with situations like Child4 would be even better:

from typing import Sequence

class Child4(Parent[Sequence[T]], Generic[T]):
    pass

so that get_parent_type_parameter(Child4[int]) would return Sequence[int] (wrapped in Found).

Is there a more robust way of accessing the type argument of a class X at runtime given an annotation A where issubclass(typing.get_origin(A), X) is True?

Why I need this:

Recent Python HTTP frameworks generate the endpoint documentation (and response schema) from the function's annotated return type. For example:

app = ...

@dataclass
class Data:
    hello: str

@app.get("/")
def hello() -> Data:
    return Data(hello="world")

I am trying to expand this to account for status code and other non-body components:

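import random
from typing import ClassVar, Union

@dataclass
class Error:
    detail: str

# Made a dataclass here so that `OkResponse(Data(...))` below has a constructor accepting `body`.
@dataclass
class ClientResponse(Generic[T]):
    status_code: ClassVar[int]
    body: T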

class OkResponse(ClientResponse[Data]):
    status_code: ClassVar[int] = 200

class BadResponse(ClientResponse[Error]):
    status_code: ClassVar[int] = 400

@app.get("/")
def hello() -> Union[OkResponse, BadResponse]:
    if random.randint(1, 2) == 1:
        return OkResponse(Data(hello="world"))

    return BadResponse(Error(detail="a_custom_error_label"))

To generate the OpenAPI documentation, my framework would inspect the annotated return type of the function passed to app.get and evaluate get_parent_type_parameter(E) (with ClientResponse hardcoded as the parent) for each member E of the Union. E would first be OkResponse, yielding Data, and then BadResponse, yielding Error. The framework then iterates through the __annotations__ of each body type and generates the response schema in the documentation for the client.
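The inspection step described above might look roughly like the sketch below. `union_members` is a hypothetical helper, not part of any framework. To keep the example self-contained, the handler returns the body types directly instead of the response wrappers. It only handles `typing.Union`, so the `X | Y` syntax from newer Python versions would need an extra check.

```python
import typing
from dataclasses import dataclass
from typing import Union


@dataclass
class Data:
    hello: str


@dataclass
class Error:
    detail: str


def union_members(annotation):
    """Hypothetical helper: the members of a Union, or the annotation itself as a 1-tuple."""
    if typing.get_origin(annotation) is Union:
        return typing.get_args(annotation)
    return (annotation,)


def hello() -> Union[Data, Error]:
    return Data(hello="world")


# Resolve the return annotation, then split it into its alternatives.
return_annotation = typing.get_type_hints(hello)["return"]
members = union_members(return_annotation)

# Field name -> field type for each body type, as raw material for a schema.
schemas = {member.__name__: typing.get_type_hints(member) for member in members}
```

In the real framework, each member would first be passed through get_parent_type_parameter to unwrap the body type before its annotations are read.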


Solution

  • The following approach is based on __class_getitem__ and __init_subclass__. It might serve your use case, but it has some severe limitations (see below), so use it at your own discretion.

    from __future__ import annotations
    
    from typing import Generic, Sequence, TypeVar
    
    
    T = TypeVar('T')
    
    
    NO_ARG = object()
    
    
    class Parent(Generic[T]):
        arg = NO_ARG  # using `arg` to store the current type argument
    
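        def __class_getitem__(cls, key):
            if cls.arg is NO_ARG or cls.arg is T:
                # Nothing stored yet, or only the bare type variable: store the key.
                cls.arg = key
            else:
                try:
                    # Substitute into a stored generic such as `Sequence[T]`.
                    cls.arg = cls.arg[key]
                except TypeError:
                    # The stored argument is not subscriptable (e.g. `int`).
                    cls.arg = key
            return super().__class_getitem__(key)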
    
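        def __init_subclass__(cls, **kwargs):
            # Let `Generic` run its own subclass setup (e.g. computing `__parameters__`).
            super().__init_subclass__(**kwargs)
            if Parent.arg is not NO_ARG:
                # Move the pending argument from `Parent` onto the new subclass.
                cls.arg, Parent.arg = Parent.arg, NO_ARG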
    
    
    class Child1(Parent[int]):
        pass
    
    
    class Child2(Child1):
        pass
    
    
    class Child3(Parent[T], Generic[T]):
        pass
    
    
    class Child4(Parent[Sequence[T]], Generic[T]):
        pass
    
    
    def get_parent_type_parameter(cls):
        return cls.arg
    
    
    classes = [
        Parent[str],
        Child1,
        Child2,
        Child3[int],
        Child4[float],
    ]
    for cls in classes:
        print(cls, get_parent_type_parameter(cls))
    
    

    Which outputs the following:

    __main__.Parent[str] <class 'str'>
    <class '__main__.Child1'> <class 'int'>
    <class '__main__.Child2'> <class 'int'>
    __main__.Child3[int] <class 'int'>
    __main__.Child4[float] typing.Sequence[float]
    

    This approach requires that every Parent[...] (i.e. every call to __class_getitem__) is followed by a subclass definition (i.e. a call to __init_subclass__); otherwise the stored argument is overwritten by the next Parent[...]. For that reason it won't work with type aliases, for example. Consider the following:

    classes = [
        Parent[str],
        Parent[int],
        Parent[float],
    ]
    for cls in classes:
        print(cls, get_parent_type_parameter(cls))
    

    which outputs:

    __main__.Parent[str] <class 'float'>
    __main__.Parent[int] <class 'float'>
    __main__.Parent[float] <class 'float'>
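For comparison, a stateless alternative avoids the overwriting problem by recomputing the answer from `__orig_bases__` on every call and substituting type variables on the way up. The following is only a sketch under some assumptions. `resolve_type_args` is a hypothetical helper name. The classes are plain `Generic` subclasses without the `arg` bookkeeping above. Only `TypeVar`-style parameters are considered (no `ParamSpec` or `TypeVarTuple`).

```python
import typing
from typing import Generic, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")


class Parent(Generic[T]):
    pass


class Child1(Parent[int]):
    pass


class Child2(Child1):
    pass


class Child3(Parent[T], Generic[T]):
    pass


class Child4(Parent[Sequence[T]], Generic[T]):
    pass


def resolve_type_args(annotation, target: type) -> Optional[Tuple]:
    """Hypothetical helper: type arguments of `target` as seen from `annotation`, else None."""
    origin = typing.get_origin(annotation) or annotation
    args = typing.get_args(annotation)
    if origin is target:
        return args
    if not (isinstance(origin, type) and issubclass(origin, target)):
        return None
    # Map the class's own type variables to the arguments it was subscripted with.
    mapping = dict(zip(getattr(origin, "__parameters__", ()), args))
    # Use only the bases declared on this class, not an inherited `__orig_bases__`.
    for base in origin.__dict__.get("__orig_bases__", origin.__bases__):
        base_origin = typing.get_origin(base) or base
        if not (isinstance(base_origin, type) and issubclass(base_origin, target)):
            continue
        params = getattr(base, "__parameters__", ())
        if base is not base_origin and params:
            # Subscripting an alias such as `Parent[Sequence[T]]` substitutes its type variables.
            base = base[tuple(mapping.get(p, p) for p in params)]
        found = resolve_type_args(base, target)
        if found is not None:
            return found
    return None
```

Under these assumptions, `resolve_type_args(Child2, Parent)` gives `(int,)`, `resolve_type_args(Child3[int], Parent)` gives `(int,)`, and `resolve_type_args(Child4[float], Parent)` gives `(Sequence[float],)`. Returning a tuple plays the role of the `Found` wrapper in the question: `None` means `Parent` is not a base at all, while `Parent[None]` yields a one-element tuple. Because nothing is stored on the classes, repeated subscriptions and type aliases do not interfere with each other.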