
Type hint class attribute depending on __init__ argument value


I have the following class:

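from typing import Literal

InputType = Literal[
    "TEXT",
    "IMAGE",
    "VIDEO",
]

# Stand-ins for the real job classes
class TextClass: ...
class ImageClass: ...
class VideoClass: ...

class MyClass:
    def __init__(self, input_type: InputType) -> None:
        # Pick the job handler that matches input_type
        if input_type == "VIDEO":
            self.jobs = VideoClass()
        elif input_type == "IMAGE":
            self.jobs = ImageClass()
        elif input_type == "TEXT":
            self.jobs = TextClass()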

I want to type hint the attribute "jobs" so that its type is the class matching input_type (VideoClass, ImageClass or TextClass).

Currently, Pylance shows "jobs" as a union of the three classes instead of the one that matches input_type.

I tried adding @overload signatures to an additional helper method, _init_job_attribute, but it didn't help.


Solution

  • You can use overloads to express your intent: make the class generic in the job type, then overload __new__ or __init__ so that the type variable is set according to the literal value of input_type.

    from __future__ import annotations
    from typing import Generic, TypeVar, Literal, cast, overload
    
    InputType = Literal[
        "TEXT",
        "IMAGE",
        "VIDEO",
    ]
    
    class TextClass: ...
    class ImageClass: ...
    class VideoClass: ...
    
    _T = TypeVar('_T', TextClass, ImageClass, VideoClass)
    
    
    class MyClass(Generic[_T]):
        jobs: _T
        
        @overload
        def __init__(self: MyClass[TextClass], input_type: Literal["TEXT"]) -> None: ...
        @overload
        def __init__(self: MyClass[ImageClass], input_type: Literal["IMAGE"]) -> None: ...
        @overload
        def __init__(self: MyClass[VideoClass], input_type: Literal["VIDEO"]) -> None: ...
        
        def __init__(self, input_type: InputType) -> None:
            if input_type == "VIDEO":
                self.jobs = cast(_T, VideoClass())  # type: ignore[redundant-cast]
            elif input_type == "IMAGE":
                self.jobs = cast(_T, ImageClass())  # type: ignore[redundant-cast]
            elif input_type == "TEXT":
                self.jobs = cast(_T, TextClass())  # type: ignore[redundant-cast]
                
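    # The "N:" comments show mypy's output. Outside a type checker these calls
    # need Python 3.11+ and "from typing import reveal_type".
    reveal_type(MyClass("VIDEO").jobs)  # N: Revealed type is "__main__.VideoClass"
    reveal_type(MyClass("IMAGE").jobs)  # N: Revealed type is "__main__.ImageClass"
    reveal_type(MyClass("TEXT").jobs)  # N: Revealed type is "__main__.TextClass"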
    

    The casts and "# type: ignore" comments above tell mypy that you are sure the instance you're assigning is exactly of type _T and not merely one of its allowed options. This is necessary because mypy does not take the overloads into account while checking the function body, so inside __init__ it cannot tie the value of input_type to _T. (The redundant-cast error code is only reported when --warn-redundant-casts is enabled, for example via --strict.) Any external caller will only be able to instantiate MyClass with the three options you allowed (provided they also run a type checker, of course).

    Here's a playground.
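The answer says __new__ can be overloaded instead of __init__ but only shows the __init__ version. Below is a minimal sketch of the __new__ route, assuming your type checker uses the return type of an overloaded __new__ when inferring MyClass(...) (recent mypy and Pyright releases do, as far as I know; confirm with reveal_type as above). Two choices here are my own, not the answer's: the hypothetical _JOB_FACTORIES dict replaces the if/elif chain, and the instance is annotated as MyClass[Any] inside the implementation, which should let the assignment type-check without casts.

```python
from __future__ import annotations

from typing import Any, Callable, Generic, Literal, TypeVar, overload

InputType = Literal["TEXT", "IMAGE", "VIDEO"]


class TextClass: ...
class ImageClass: ...
class VideoClass: ...


_T = TypeVar("_T", TextClass, ImageClass, VideoClass)

# Hypothetical lookup table (not from the answer); keys mirror InputType.
_JOB_FACTORIES: dict[str, Callable[[], Any]] = {
    "TEXT": TextClass,
    "IMAGE": ImageClass,
    "VIDEO": VideoClass,
}


class MyClass(Generic[_T]):
    jobs: _T

    # Each overload maps one literal to one specialisation of MyClass.
    @overload
    def __new__(cls, input_type: Literal["TEXT"]) -> MyClass[TextClass]: ...
    @overload
    def __new__(cls, input_type: Literal["IMAGE"]) -> MyClass[ImageClass]: ...
    @overload
    def __new__(cls, input_type: Literal["VIDEO"]) -> MyClass[VideoClass]: ...

    def __new__(cls, input_type: InputType) -> MyClass[Any]:
        # Typing the new instance as MyClass[Any] makes _T behave as Any here,
        # so assigning whichever job object the table returns needs no cast.
        instance: MyClass[Any] = super().__new__(cls)
        instance.jobs = _JOB_FACTORIES[input_type]()
        return instance


video = MyClass("VIDEO")
# Runtime check; a type checker should infer video.jobs as VideoClass.
assert isinstance(video.jobs, VideoClass)
```

Because MyClass defines no __init__, object.__init__ receives input_type and ignores it; CPython allows the extra argument because __new__ is overridden. If you later add an __init__, give it the same parameters as __new__.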