Tags: python, python-typing, algebraic-data-types, discriminated-union

How can I create a union type that can also be used to instantiate union members in Python?


I'm currently building an algebraic data type to represent the state of a task, as per this question, but I want to extend it to make it a little cleaner to use.

Here are the definitions of my states:

from dataclasses import dataclass

@dataclass
class StatusWaiting:
    # Nothing
    pass

@dataclass
class StatusRunning:
    progress: float

@dataclass
class StatusComplete:
    result: int

@dataclass
class StatusFailure:
    reason: str

However, the way that I intend to use these variants has two seemingly incompatible behaviours:

Type annotation should behave as follows:

Status = StatusWaiting | StatusRunning | StatusComplete | StatusFailure

def get_status() -> Status:
    ...

Instantiation should behave as follows:

class Status:
    Waiting = StatusWaiting
    Running = StatusRunning
    Failed = StatusFailure
    Complete = StatusComplete

my_status = Status.Running(0.42)

How can I define the Status type so that I can have it behave as a union when used as a type annotation, and also behave as a collection of the variants for simple initialization?

Status = ???

def get_status() -> Status:
    return Status.Failed("Something has gone horribly wrong")

I've tried using an Enum, but this doesn't appear to allow for instantiation.

from enum import Enum

class Status(Enum):
    Waiting = StatusWaiting
    Running = StatusRunning
    Complete = StatusComplete
    Failure = StatusFailure


def get_status() -> Status:
    # Mypy: "Status" not callable
    return Status.Complete(42)

Solution

  • This can be achieved by making Status an abstract base class with one classmethod per variant, each of which constructs and returns that variant.

    Since all of the variants inherit from Status, any of them can be annotated as Status for type-checking.

    The downside of this approach is that every variant inherits those classmethods as well, so a call such as StatusRunning.Failure("...") is also valid.

    # Postponed evaluation of annotations (PEP 563) lets the return
    # annotations below refer to classes that are defined later in the file
    from __future__ import annotations
    
    from abc import ABC
    from dataclasses import dataclass
    
    
    # Abstract base class with class methods to instantiate each variant
    class Status(ABC):
        @classmethod
        def Waiting(cls) -> StatusWaiting:
            return StatusWaiting()
    
        @classmethod
        def Running(cls, progress: float) -> StatusRunning:
            return StatusRunning(progress)
    
        @classmethod
        def Complete(cls, result: int) -> StatusComplete:
            return StatusComplete(result)
    
        @classmethod
        def Failure(cls, reason: str) -> StatusFailure:
            return StatusFailure(reason)
    
    
    @dataclass
    class StatusWaiting(Status):
        # Nothing
        pass
    
    
    @dataclass
    class StatusRunning(Status):
        progress: float
    
    
    @dataclass
    class StatusComplete(Status):
        result: int
    
    
    @dataclass
    class StatusFailure(Status):
        reason: str
    
    
    # This is type-safe
    def get_status() -> Status:
        return Status.Failure("Something has gone horribly wrong")
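
As a rough illustration of how this might be used, the sketch below repeats the definitions from the answer and adds a hypothetical describe() helper (not part of the original answer) that narrows a Status to its concrete variant with isinstance checks. The last lines show the downside mentioned above: a variant can be built through a sibling variant's inherited classmethod.

```python
from __future__ import annotations

from abc import ABC
from dataclasses import dataclass


# Same base class as in the answer: one factory classmethod per variant
class Status(ABC):
    @classmethod
    def Waiting(cls) -> StatusWaiting:
        return StatusWaiting()

    @classmethod
    def Running(cls, progress: float) -> StatusRunning:
        return StatusRunning(progress)

    @classmethod
    def Complete(cls, result: int) -> StatusComplete:
        return StatusComplete(result)

    @classmethod
    def Failure(cls, reason: str) -> StatusFailure:
        return StatusFailure(reason)


@dataclass
class StatusWaiting(Status):
    pass


@dataclass
class StatusRunning(Status):
    progress: float


@dataclass
class StatusComplete(Status):
    result: int


@dataclass
class StatusFailure(Status):
    reason: str


# Hypothetical helper for illustration only
def describe(status: Status) -> str:
    # Each isinstance check narrows Status to one concrete variant,
    # so the variant-specific fields can be accessed safely
    if isinstance(status, StatusWaiting):
        return "waiting"
    if isinstance(status, StatusRunning):
        return f"running ({status.progress:.0%})"
    if isinstance(status, StatusComplete):
        return f"complete: {status.result}"
    if isinstance(status, StatusFailure):
        return f"failed: {status.reason}"
    # Reached only for a Status that is none of the four variants
    raise TypeError(f"unexpected status: {status!r}")


print(describe(Status.Running(0.42)))
print(describe(Status.Failure("Something has gone horribly wrong")))

# The downside: variants inherit the factory classmethods too
odd = StatusRunning.Failure("constructed via a sibling variant")
print(type(odd).__name__)
```

Unlike a union alias, a base class is open to further subclasses, so a type checker generally cannot confirm that the isinstance chain covers every case. The final raise is there as a runtime guard for that reason.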