python, python-typing

Discriminated union over instance attributes


Suppose I have the following class:

class MyClass:
    def __init__(self, ...):
        self.attr1: Union[Attr11, Attr12]
        self.attr2: Literal["val1", "val2", "val3"]
        self.attr3: Attr3
        ...
        self.attrn: Attrn

I want the type checker to narrow self.attr1 based on the value of self.attr2:

if self.attr2 == "val1":
    ...  # self.attr1 should be Attr11 here
else:
    ...  # self.attr1 should be Attr12 here

How can I do it?


Solution

  • You can achieve this with TypeForm (PEP 747), an experimental addition that might come in Python 3.14.

    Note: As of January 2025, this is supported only by pyright, and only with experimental features enabled; for example, add the following to your pyproject.toml file:

    [tool.pyright]
    enableExperimentalFeatures = true
    

    from typing import Union, Literal, TypeGuard, Any, TypeVar, reveal_type
    from typing_extensions import TypeForm, TypeIs
    
    T = TypeVar("T")
    
    def conditional_cast(attr: Any, condition: bool, to_type: TypeForm[T]) -> TypeIs[T]:
        """If condition is True, attr is narrowed to the type passed as to_type."""
        return condition
    
    
    class MyClass:
        def __init__(self, attr1: Union[int, str], attr2: Literal["val1", "val2", "val3"]):
            self.attr1: Union[int, str] = attr1
            self.attr2: Literal["val1", "val2", "val3"] = attr2
    
            if conditional_cast(self.attr1, self.attr2 == "val1", int):
                reveal_type(self.attr1)  # Type of "self.attr1" is "int"
            else:
                # Note: using TypeGuard instead of TypeIs will not narrow the type here, i.e. it stays int | str
                reveal_type(self.attr1)  # Type of "self.attr1" is "str"
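If the target types are plain classes (as int and str are here), a variant that takes type[T] instead of TypeForm[T] may be enough. As far as I know, TypeIs (PEP 742) works in recent pyright and mypy versions without experimental flags; TypeForm is only needed when the target is not a class, e.g. Literal[...] or a union. This is a minimal sketch under that assumption; the describe method is a hypothetical addition for illustration, and TypeIs is imported only under TYPE_CHECKING so the snippet also runs without typing_extensions installed:

```python
from __future__ import annotations  # annotations are not evaluated at runtime

from typing import TYPE_CHECKING, Literal, TypeVar, Union

if TYPE_CHECKING:
    # Only the type checker needs TypeIs (typing.TypeIs exists from Python 3.13)
    from typing_extensions import TypeIs

T = TypeVar("T")


def conditional_cast(attr: object, condition: bool, to_type: type[T]) -> TypeIs[T]:
    """Narrow attr to to_type when condition is True; no runtime check is done."""
    return condition


class MyClass:
    def __init__(self, attr1: Union[int, str], attr2: Literal["val1", "val2", "val3"]) -> None:
        self.attr1: Union[int, str] = attr1
        self.attr2: Literal["val1", "val2", "val3"] = attr2

    def describe(self) -> str:
        # Hypothetical method: the checker should see self.attr1 as int in this branch
        if conditional_cast(self.attr1, self.attr2 == "val1", int):
            return f"int: {self.attr1 + 1}"
        # ...and as str here, because TypeIs also narrows the negative branch
        return f"str: {self.attr1.upper()}"
```

For example, MyClass(1, "val1").describe() returns "int: 2", and MyClass("a", "val2").describe() returns "str: A".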
    

    Word of warning: Because attr is typed as Any, you lose type warnings for the arguments passed to conditional_cast, and the function never checks the value at runtime; it simply trusts condition. You may want to add runtime checks, at least for testing.
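One way to add such a check, sketched under the assumption that the target is a runtime class (isinstance needs one, which is another reason to use type[T] here rather than TypeForm[T]): the hypothetical checked_conditional_cast below raises if condition disagrees with the value's actual type. TypeIs promises that the result is True exactly when the argument has the narrowed type, so a disagreement means the attr1/attr2 invariant of the class has been broken.

```python
from __future__ import annotations  # annotations are not evaluated at runtime

from typing import TYPE_CHECKING, TypeVar

if TYPE_CHECKING:
    from typing_extensions import TypeIs

T = TypeVar("T")


def checked_conditional_cast(attr: object, condition: bool, to_type: type[T]) -> TypeIs[T]:
    """Hypothetical helper: like conditional_cast, but verifies condition at runtime.

    Raises TypeError when condition and isinstance(attr, to_type) disagree,
    so the function only ever returns a value consistent with the TypeIs contract.
    """
    is_instance = isinstance(attr, to_type)
    if condition != is_instance:
        raise TypeError(
            f"condition={condition!r}, but {attr!r} is "
            f"{'' if is_instance else 'not '}an instance of {to_type.__name__}"
        )
    return condition
```

Usage is the same as before, e.g. checked_conditional_cast(self.attr1, self.attr2 == "val1", int); a mismatched instance such as attr1="a" with attr2="val1" now fails loudly instead of being silently mistyped.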