Tags: python, generics, python-typing

Dataclasses: Matching Generic TypeVar names to attributes in the origin class


Say I have a Generic dataclass like the following:

from dataclasses import dataclass
from typing import TypeVar, Generic

T = TypeVar('T')
U = TypeVar('U')


@dataclass
class Class(Generic[T, U]):
    foo: U
    bar: T


IntStrClass = Class[int, str]

Reading the code, you can see that for IntStrClass:

  • the T lines up with int, which makes the type of bar an int.
  • the U lines up with str, which makes the type of foo a str.
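The pairing in those two bullets is positional: `Class[int, str]` fills the type variables in the order they are listed in `Generic[T, U]`, regardless of which field uses which variable. Here is a minimal check of that. It assumes Python 3.8+ (for `get_args`) and relies on the `__parameters__` attribute that `typing.Generic` sets on the class:

```python
from dataclasses import dataclass
from typing import Generic, TypeVar, get_args

T = TypeVar('T')
U = TypeVar('U')


@dataclass
class Class(Generic[T, U]):
    foo: U
    bar: T


IntStrClass = Class[int, str]

# __parameters__ lists the TypeVars in the order they appear in Generic[T, U].
# That order, not the field order, decides which argument each TypeVar gets.
print(Class.__parameters__)  # (~T, ~U)

type_var_map = dict(zip(Class.__parameters__, get_args(IntStrClass)))
print(type_var_map)  # {~T: <class 'int'>, ~U: <class 'str'>}
```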

But how can I figure this out programmatically?

I've been playing around with the typing module, but can't see from the outputs how I would match them up. What I have is:

from typing import get_type_hints, get_origin, get_args

print("Class field types:", get_type_hints(get_origin(IntStrClass)))
print("Class generic args:", get_args(IntStrClass))

Output:

Class field types: {'foo': ~U, 'bar': ~T}
Class generic args: (<class 'int'>, <class 'str'>)

What I'm missing is a way to determine, from the definition of Class, that T -> int and U -> str. With that information, I could infer the proper types of foo and bar.
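Once the TypeVar-to-type mapping is known, the field types follow directly. The sketch below does that inference with `dataclasses.fields()` instead of `get_type_hints()`. `resolve_field_types` is a hypothetical helper name. The sketch assumes annotations are evaluated objects (no `from __future__ import annotations`), because it reads `Field.type` as-is:

```python
from dataclasses import dataclass, fields
from typing import Generic, TypeVar, get_args, get_origin

T = TypeVar('T')
U = TypeVar('U')


@dataclass
class Class(Generic[T, U]):
    foo: U
    bar: T


IntStrClass = Class[int, str]


def resolve_field_types(alias):
    """Hypothetical helper: map each dataclass field to its concrete type."""
    origin = get_origin(alias)  # Class
    type_var_map = dict(zip(origin.__parameters__, get_args(alias)))
    resolved = {}
    for f in fields(origin):
        # Field.type holds the raw annotation, e.g. the TypeVar U for foo.
        if isinstance(f.type, TypeVar):
            resolved[f.name] = type_var_map[f.type]
        else:
            resolved[f.name] = f.type
    return resolved


print(resolve_field_types(IntStrClass))  # {'foo': <class 'str'>, 'bar': <class 'int'>}
```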

Thanks in advance!


Solution

  • How about this?

    [Has been significantly edited following a conversation in the comments.]

    from dataclasses import dataclass
    from typing import TypeVar, Generic, get_type_hints, get_args, get_origin
    
    T = TypeVar('T')
    U = TypeVar('U')
    
    
    @dataclass
    class Class(Generic[T, U]):
        foo: U
        spam: str
        bar: T
        baz: int
    
    
    IntStrClass = Class[int, str]
    
    def get_annotations(generic_subclass):
        generic_origin = get_origin(generic_subclass)
        annotations_map = get_type_hints(generic_origin)
        generic_args = get_args(generic_subclass)
    
        try:
            generic_params = generic_origin.__parameters__
        except AttributeError as err:
            raise AttributeError(
                f"{generic_origin} has no attribute '__parameters__'. "
                "The likely cause of this is that the typing module's "
                "API for the Generic class has changed "
                "since this function was written."
                ) from err
    
        type_var_map = dict(zip(generic_params, generic_args))
        
        for field, annotation in annotations_map.items():
            if isinstance(annotation, TypeVar):
                annotations_map[field] = type_var_map[annotation]
                
        return annotations_map
    
    print("Resolved attributes:", get_annotations(IntStrClass))

    Output:

    Resolved attributes: {'foo': <class 'str'>, 'spam': <class 'str'>, 'bar': <class 'int'>, 'baz': <class 'int'>}
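The answer's get_annotations only replaces a field whose annotation is a bare TypeVar, so a field annotated as `List[T]` would keep its TypeVar. Below is a sketch of a variant that also substitutes nested TypeVars. It assumes Python 3.8+ and relies on the fact that a parameterised `typing` alias such as `List[T]` exposes its free variables in `__parameters__` and can be subscripted again to fill them in (for example, `List[T][int]` gives `List[int]`). `substitute` is a hypothetical helper name:

```python
from dataclasses import dataclass
from typing import (Dict, Generic, List, Optional, TypeVar,
                    get_args, get_origin, get_type_hints)

T = TypeVar('T')
U = TypeVar('U')


@dataclass
class Class(Generic[T, U]):
    foo: U
    bar: List[T]
    baz: Dict[T, U]
    qux: Optional[U]
    spam: str


def substitute(annotation, type_var_map):
    """Hypothetical helper: replace TypeVars, including nested ones."""
    if isinstance(annotation, TypeVar):
        # Unmapped TypeVars are left as they are.
        return type_var_map.get(annotation, annotation)
    # Aliases like List[T] list their free TypeVars in __parameters__;
    # plain classes such as str have no such attribute.
    params = getattr(annotation, '__parameters__', ())
    if params:
        return annotation[tuple(type_var_map.get(p, p) for p in params)]
    return annotation


def get_annotations(generic_subclass):
    generic_origin = get_origin(generic_subclass)
    type_var_map = dict(zip(generic_origin.__parameters__, get_args(generic_subclass)))
    return {
        name: substitute(annotation, type_var_map)
        for name, annotation in get_type_hints(generic_origin).items()
    }


print(get_annotations(Class[int, str]))
```

For `Class[int, str]`, this should resolve `bar` to `List[int]`, `baz` to `Dict[int, str]` and `qux` to `Optional[str]`, while `foo` and `spam` become `str`.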