Say I have a Generic dataclass like the following:
```python
from dataclasses import dataclass
from typing import TypeVar, Generic

T = TypeVar('T')
U = TypeVar('U')

@dataclass
class Class(Generic[T, U]):
    foo: U
    bar: T

IntStrClass = Class[int, str]
```
Reading the code, you can see that for `IntStrClass`: `T` lines up with `int`, which makes the type of `bar` an `int`. `U` lines up with `str`, which makes the type of `foo` a `str`. But how can I figure this out programmatically?
I've been playing around with the `typing` module, but can't see from the outputs how I would match them up. What I have is:
```python
from typing import get_type_hints, get_origin, get_args

print("Class field types:", get_type_hints(get_origin(IntStrClass)))
print("Class generic args:", get_args(IntStrClass))
```

Output:

```
Class field types: {'foo': ~U, 'bar': ~T}
Class generic args: (<class 'int'>, <class 'str'>)
```
What I'm missing is a way to determine, from the definition of `Class`, that `T -> int` and `U -> str`. If I had this information, I could infer the proper types of `foo` and `bar`.
Thanks in advance!
How about this?
[Has been significantly edited following a conversation in the comments.]
```python
from dataclasses import dataclass
from typing import TypeVar, Generic, get_type_hints, get_args, get_origin

T = TypeVar('T')
U = TypeVar('U')

@dataclass
class Class(Generic[T, U]):
    foo: U
    spam: str
    bar: T
    baz: int

IntStrClass = Class[int, str]

def get_annotations(generic_subclass):
    # Class[int, str] -> Class
    generic_origin = get_origin(generic_subclass)
    # Field annotations as declared on the class, e.g. {'foo': ~U, ...}
    annotations_map = get_type_hints(generic_origin)
    # Concrete arguments in subscription order: (int, str)
    generic_args = get_args(generic_subclass)
    try:
        # TypeVars in the order they appear in Generic[...]: (~T, ~U)
        generic_params = generic_origin.__parameters__
    except AttributeError as err:
        raise AttributeError(
            f"{generic_origin} has no attribute '__parameters__'. "
            "The likely cause of this is that the typing module's "
            "API for the Generic class has changed "
            "since this function was written."
        ) from err
    # Pair each TypeVar with its concrete argument: {~T: int, ~U: str}
    type_var_map = dict(zip(generic_params, generic_args))
    for field, annotation in annotations_map.items():
        # Only bare TypeVars are substituted; other annotations are kept as-is
        if isinstance(annotation, TypeVar):
            annotations_map[field] = type_var_map[annotation]
    return annotations_map

print("Resolved attributes:", get_annotations(IntStrClass))
```
```
Resolved attributes: {'foo': <class 'str'>, 'spam': <class 'str'>, 'bar': <class 'int'>, 'baz': <class 'int'>}
```
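`get_annotations` only substitutes annotations that are themselves bare `TypeVar`s, so a field annotated as `list[T]` would come back unchanged. Below is a minimal sketch of one way to extend it, assuming Python 3.9+ (for `list[...]` and `dict[...]` in annotations). It relies on generic aliases exposing the type variables they contain through `__parameters__` (in order of first appearance), and on subscripting such an alias with concrete types to substitute them. The class `Nested` and the functions `substitute` and `resolve_field_types` are hypothetical names chosen for illustration.

```python
from dataclasses import dataclass
from typing import Generic, Optional, TypeVar, get_args, get_origin, get_type_hints

T = TypeVar("T")
U = TypeVar("U")


# Hypothetical example class with TypeVars nested inside other annotations
@dataclass
class Nested(Generic[T, U]):
    items: list[T]
    lookup: dict[U, T]
    maybe: Optional[U]
    plain: int
    bare: T


def substitute(annotation, type_var_map):
    """Replace TypeVars in an annotation, including inside generic aliases."""
    if isinstance(annotation, TypeVar):
        return type_var_map.get(annotation, annotation)
    # Aliases such as list[T], dict[U, T] or Optional[U] list the TypeVars they
    # contain in __parameters__; plain classes like int have no such attribute.
    params = getattr(annotation, "__parameters__", ())
    if params:
        # Subscripting the alias substitutes its TypeVars, e.g. dict[U, T][str, int]
        return annotation[tuple(type_var_map.get(p, p) for p in params)]
    return annotation


def resolve_field_types(generic_alias):
    """Map each field name of a parameterized generic class to its concrete type."""
    origin = get_origin(generic_alias)
    # Pair the class's TypeVars with the concrete arguments, e.g. {~T: int, ~U: str}
    type_var_map = dict(zip(origin.__parameters__, get_args(generic_alias)))
    return {
        name: substitute(hint, type_var_map)
        for name, hint in get_type_hints(origin).items()
    }


print(resolve_field_types(Nested[int, str]))
```

For example, `resolve_field_types(Nested[int, str])` maps `items` to `list[int]`, `lookup` to `dict[str, int]` and `maybe` to `Optional[str]`, while `plain` stays `int`.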