How to narrow types in Python with Enum

In Python, consider the following example:

import random
from enum import StrEnum
from typing import Literal, overload

class A(StrEnum):
    X = "X"
    Y = "Y"

class X: ...
class Y: ...

@overload
def enum_to_cls(var: Literal[A.X]) -> type[X]: ...

@overload
def enum_to_cls(var: Literal[A.Y]) -> type[Y]: ...

def enum_to_cls(var: A) -> type[X] | type[Y]:
    match var:
        case A.X:
            return X
        case A.Y:
            return Y
        case _:
            raise ValueError(f"Unknown enum value: {var}")

When I attempt to call enum_to_cls, I get a type error, with the following case:

selected_enum = random.choice([x for x in A])
enum_to_cls(selected_enum)

# Argument of type "A" cannot be assigned to parameter "var" of type "Literal[A.Y]" in
# function "enum_to_cls" 
# "A" is not assignable to type "Literal[A.Y]" [reportArgumentType]

I understand the error and it makes sense, but I would like to know whether there is any way to avoid it. I know I could avoid it by creating a branch for each enum case, but that puts me back at square one, since avoiding such branching is why I wanted to create the function enum_to_cls.


  • The simple workaround is to include the implementation's signature as a third overload:

    (playgrounds: Pyright, Mypy)

    @overload
    def enum_to_cls(var: Literal[A.X]) -> type[X]: ...
    @overload
    def enum_to_cls(var: Literal[A.Y]) -> type[Y]: ...
    @overload
    def enum_to_cls(var: A) -> type[X] | type[Y]: ...
    def enum_to_cls(var: A) -> type[X] | type[Y]:
        ...  # Implementation goes here

    selected_enum = random.choice([x for x in A])
    reveal_type(enum_to_cls(selected_enum))  # type[X] | type[Y]

    This is a workaround rather than a definitive solution, because the enum A is supposed to be expanded to the union of its members during overload evaluation. Indeed, if the argument were declared to be of the type Literal[A.X, A.Y], there would be no error:

    (playgrounds: Pyright, Mypy)

    @overload
    def enum_to_cls(var: Literal[A.X]) -> type[X]: ...
    @overload
    def enum_to_cls(var: Literal[A.Y]) -> type[Y]: ...
    def enum_to_cls(var: A) -> type[X] | type[Y]:
        ...  # Implementation goes here

    selected_enum: Literal[A.X, A.Y] = ...
    reveal_type(enum_to_cls(selected_enum))  # type[X] | type[Y]

    Type expansion during overload evaluation is part of a recently proposed addition to the typing specification. Pyright has yet to conform to this, but it will once the proposal is accepted.