Tags: python, enums, python-typing, pyright

How to have Pyright infer type from an enum check?


Can types be associated with enums, so that Pyright can infer the type from an equality check? (Without cast() or isinstance().)

from dataclasses import dataclass
from enum import Enum, auto

class Type(Enum):
    FOO = auto()
    BAR = auto()

@dataclass
class Foo:
    type: Type

@dataclass
class Bar:
    type: Type

item = next(i for i in (Foo(Type.FOO), Bar(Type.BAR)) if i.type == Type.BAR)
reveal_type(item)  # How to have this be `Bar` instead of `Foo | Bar`?

Solution

  • You want a discriminated union (also known as a tagged union).

    In a discriminated union, every member has a discriminator (also known as a tag field) whose type is different for each member, so checking its value tells you which member you have.

    You currently have a union of Foo and Bar and want to discriminate between them using the .type attribute. However, that attribute cannot be the discriminator, because it is declared as Type in both classes: nothing stops a Foo from holding Type.BAR, or a Bar from holding Type.FOO.
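As a minimal, generic sketch of the idea (the Circle and Square classes, the kind field and the area function are hypothetical and not part of the question), each member below declares its tag with a different Literal type, so a single equality check on the tag should be enough for a checker such as Pyright to tell the members apart:

```python
import math
from dataclasses import dataclass
from typing import Literal, Union


@dataclass
class Circle:
    kind: Literal["circle"]  # tag value unique to Circle
    radius: float


@dataclass
class Square:
    kind: Literal["square"]  # tag value unique to Square
    side: float


# Hypothetical tagged union: the `kind` field is the discriminator.
Shape = Union[Circle, Square]


def area(shape: Shape) -> float:
    if shape.kind == "circle":
        # A checker that supports tagged-union narrowing treats `shape` as Circle here,
        # so accessing `.radius` is accepted.
        return math.pi * shape.radius**2
    # "square" is the only remaining tag, so `shape` should be narrowed to Square here.
    return shape.side**2
```

At runtime this is ordinary Python; the narrowing only affects what the type checker accepts.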


    for i in (Foo(Type.FOO), Bar(Type.BAR)):
        reveal_type(i)  # Foo | Bar
    
    mischievous_foo = Foo(Type.BAR)  # This is valid
    naughty_bar = Bar(Type.FOO)      # This too
    
    for i in (mischievous_foo, naughty_bar):
        if i.type == Type.FOO:
            reveal_type(i)           # Foo | Bar; at runtime it is naughty_bar, a Bar
    

    If Foo.type can only ever be Type.FOO and Bar.type can only ever be Type.BAR, reflect this in the annotations by using Literal types:

    (Making type a dataclass field no longer makes sense at this point, but I'm assuming they are only dataclasses for the purpose of this question.)

    from typing import Literal
    
    @dataclass
    class Foo:
        type: Literal[Type.FOO]
    
    @dataclass
    class Bar:
        type: Literal[Type.BAR]
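Picking up the remark above that type no longer really needs to be passed in by the caller, one possible variation (an assumption of this sketch, not part of the original answer) is to give each class its tag as a fixed default with init=False. The declared type of .type is still a Literal, so the same narrowing should apply, and a mismatched tag can no longer be passed to the constructor at all:

```python
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Literal


class Type(Enum):
    FOO = auto()
    BAR = auto()


@dataclass
class Foo:
    # The tag is fixed per class and excluded from the generated __init__.
    type: Literal[Type.FOO] = field(default=Type.FOO, init=False)


@dataclass
class Bar:
    type: Literal[Type.BAR] = field(default=Type.BAR, init=False)


item = next(i for i in (Foo(), Bar()) if i.type == Type.BAR)
# Pyright should infer `item` as Bar, since `.type` is declared with a Literal type.
```

The trade-off is that the tag still appears in the generated __repr__ and __eq__, which is usually harmless.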
    

    Because Literal[Type.FOO] and Literal[Type.BAR] are disjoint types, comparing i.type against an enum member now narrows i:


    for i in (Foo(Type.FOO), Bar(Type.BAR)):
        if i.type == Type.FOO:
            reveal_type(i)           # Foo
    
    Foo(Type.BAR)  # now rejected by the type checker
    Bar(Type.FOO)  # rejected as well
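Since the tags are disjoint literals, a chain of checks that covers every member should leave nothing for a final else branch. A common pattern builds on that to let the type checker flag a forgotten case. The sketch below uses a hand-written assert_never helper and a describe function, both illustrative assumptions (Python 3.11+ also ships typing.assert_never):

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Literal, NoReturn, Union


class Type(Enum):
    FOO = auto()
    BAR = auto()


@dataclass
class Foo:
    type: Literal[Type.FOO]


@dataclass
class Bar:
    type: Literal[Type.BAR]


def assert_never(value: NoReturn) -> NoReturn:
    # Type checkers only accept a call here with a value narrowed to Never,
    # so an unhandled member of the union shows up as a type error.
    raise AssertionError(f"Unhandled value: {value!r}")


def describe(i: Union[Foo, Bar]) -> str:
    if i.type == Type.FOO:
        return "foo"  # i narrowed to Foo
    elif i.type == Type.BAR:
        return "bar"  # i narrowed to Bar
    else:
        # Both tags are handled, so i should be narrowed to Never here.
        assert_never(i)
```

If a third enum member and class were added to the union without a matching branch, the assert_never(i) call should become a type error.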
    

    This also works inside a generator expression, as in the original question:

    item = next(i for i in (Foo(Type.FOO), Bar(Type.BAR)) if i.type == Type.BAR)
    reveal_type(item)                # Bar
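One practical caveat: next() raises StopIteration when no item matches. Passing a default as the second argument avoids that, and the narrowed element type should carry through, giving Bar | None. The first_bar helper below is hypothetical, and its annotated return type doubles as a check that the narrowing happened:

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Literal, Optional, Sequence, Union


class Type(Enum):
    FOO = auto()
    BAR = auto()


@dataclass
class Foo:
    type: Literal[Type.FOO]


@dataclass
class Bar:
    type: Literal[Type.BAR]


def first_bar(items: Sequence[Union[Foo, Bar]]) -> Optional[Bar]:
    # The generator only yields items narrowed to Bar; None is returned if none match.
    return next((i for i in items if i.type == Type.BAR), None)
```

Without the Literal annotations, a type checker would infer Foo | Bar | None for the next() call and reject the Optional[Bar] return type.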