How to narrow types in Python with Enum


In Python, consider the following example:

import random  # used by the call further down
from enum import StrEnum
from typing import Literal, overload


class A(StrEnum):
    X = "X"
    Y = "Y"


class X: ...
class Y: ...


@overload
def enum_to_cls(var: Literal[A.X]) -> type[X]: ...

@overload
def enum_to_cls(var: Literal[A.Y]) -> type[Y]: ...

def enum_to_cls(var: A) -> type[X] | type[Y]:
    match var:
        case A.X:
            return X
        case A.Y:
            return Y
        case _:
            raise ValueError(f"Unknown enum value: {var}")

When I call enum_to_cls with an argument whose type is the enum A itself, I get a type error:

selected_enum = random.choice([x for x in A])
enum_to_cls(selected_enum)

# Argument of type "A" cannot be assigned to parameter "var" of type "Literal[A.Y]" in
# function "enum_to_cls" 
# "A" is not assignable to type "Literal[A.Y]" [reportArgumentType]

I understand the error and it makes sense, but I would like to know whether there is any way to avoid it. I know I can avoid it by creating a branch for each enum case, but then I am back to square one: avoiding that branching is why I wanted to create the function enum_to_cls in the first place.


Solution

  • The simple workaround is to include the implementation's signature as a third overload:

    (playgrounds: Pyright, Mypy)

    @overload
    def enum_to_cls(var: Literal[A.X]) -> type[X]: ...
    
    @overload
    def enum_to_cls(var: Literal[A.Y]) -> type[Y]: ...
    
    @overload
    def enum_to_cls(var: A) -> type[X] | type[Y]: ...
    
    def enum_to_cls(var: A) -> type[X] | type[Y]:
        ...  # Implementation goes here
    
    selected_enum = random.choice([x for x in A])
    reveal_type(enum_to_cls(selected_enum))  # type[X] | type[Y]
    

    This is a workaround rather than a definitive solution, because a type checker is supposed to expand the enum A to the union of its members, Literal[A.X, A.Y], during overload evaluation and match each member against the overloads separately. Indeed, if the argument is declared with the type Literal[A.X, A.Y], there is no error:

    (playgrounds: Pyright, Mypy)

    @overload
    def enum_to_cls(var: Literal[A.X]) -> type[X]: ...
    
    @overload
    def enum_to_cls(var: Literal[A.Y]) -> type[Y]: ...
    
    def enum_to_cls(var: A) -> type[X] | type[Y]:
        ...  # Implementation goes here
    
    selected_enum: Literal[A.X, A.Y] = ...
    reveal_type(enum_to_cls(selected_enum))  # type[X] | type[Y]
    

    Type expansion during overload evaluation is part of a recently proposed addition to the typing specification. Pyright does not yet conform to it for enums, but it will once the proposal is accepted.