
How to implement Enum with mutable members, comparable and hashable by their index


I am writing a class to represent sequential stages of an industrial process, composed of ADMISSION, PROCESSING, QC and DELIVERY stages.

Each stage has a unique, progressive sequence number, a mnemonic name and a field keeping track of the number of instances going through it:

from dataclasses import dataclass

@dataclass
class Stage:
    seq_id:       int
    name:         str
    n_instances:  int
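It helps to see why requirement 1.2 below is not free, even before any enum is involved. With the default `eq=True` and `frozen=False`, `@dataclass` sets `__hash__` to `None`, so `Stage` instances cannot be dictionary keys. Forcing a field-based hash with `unsafe_hash=True` makes the hash depend on `n_instances`, which changes at runtime. A minimal sketch (`HashedStage` is a hypothetical name used only here):

```python
from dataclasses import dataclass

@dataclass
class Stage:
    seq_id:       int
    name:         str
    n_instances:  int

# Default dataclass (eq=True, frozen=False): __hash__ is set to None
try:
    {Stage(0, "admission", 0): "#B10156"}
    stage_is_hashable = True
except TypeError:              # unhashable type
    stage_is_hashable = False

# Hypothetical variant that forces a hash over all fields
@dataclass(unsafe_hash=True)
class HashedStage:
    seq_id:       int
    name:         str
    n_instances:  int

s = HashedStage(0, "admission", 0)
colors = {s: "#B10156"}
s.n_instances += 1             # mutating a hashed field changes hash(s)...
lost_after_mutation = s not in colors   # ...so the key can no longer be found
```

So a usable hash has to be based on something that never changes, such as the sequence number alone.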

Since the stages are well known and are not supposed to change during execution, I decided to gather them in an Enum.
The enum needs to satisfy the following requirements:

Enum members need to:

  1. be instances of Stage (i.e. the enum subclasses Stage), so that they can be used directly without accessing their value, which makes them easier to use (akin to IntEnum or StrEnum). In particular:

    1. be comparable by their seq_id (e.g. Stages.DELIVERY > Stages.PROCESSING is true)

    2. be usable as dictionary keys

    3. use name as their __str__ representation.

  2. have immutable, sequential seq_ids from 0 to n, assigned by declaration order

  3. get their name at member declaration; using auto() yields the lower-cased member name.
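Requirement 1 borrows the idea behind IntEnum: its members are real `int` instances, so they can be compared and used directly without `.value`. A short illustration with a hypothetical `Priority` enum (not part of the question):

```python
from enum import IntEnum

class Priority(IntEnum):    # hypothetical example enum
    LOW = 0
    HIGH = 1

# Members are genuine ints: no .value needed to compare or do arithmetic
assert isinstance(Priority.HIGH, int)
assert Priority.HIGH > Priority.LOW
assert Priority.HIGH + 1 == 2

# They also work as dictionary keys
colors = {Priority.LOW: "green", Priority.HIGH: "red"}
assert colors[Priority.HIGH] == "red"
```

The question asks for the same behavior with Stage, rather than int, as the mixin type.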

I managed to address points 2 and 3 in my implementation (see below).
How can I implement point 1 (and its subpoints)?

Final implementation

(fixed thanks to @EthanFurman's answer)

The idea is to use Stage instances as the enum members, and use their seq_ids as member values.
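Using seq_ids as member values relies on the auto-numbering pattern from the enum documentation, where `__new__` sets `_value_` from the number of members defined so far. A stripped-down sketch, adapted to count from 0 as requirement 2 asks; the class names are illustrative, and the string given to each member is simply discarded here:

```python
from enum import Enum

class AutoNumber(Enum):
    def __new__(cls, *args):
        # While the enum class is being built, __members__ only holds the
        # members declared so far, so its length is the next 0-based index.
        obj = object.__new__(cls)
        obj._value_ = len(cls.__members__)
        return obj

class Stages(AutoNumber):
    ADMISSION = "admission"   # the argument is ignored by AutoNumber.__new__
    PROCESSING = "processing"
    QC = "qc"
    DELIVERY = "delivery"

assert [stage.value for stage in Stages] == [0, 1, 2, 3]
assert Stages(2) is Stages.QC
```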

  • seq_id is no longer a field of the Stage class: it is stored directly in the enum member's _value_. I made this choice because I consider seq_ids meaningless outside the enum.

  • seq_id cannot be specified at member declaration. Instead, it is generated automatically from declaration order (using the auto-numbering enum pattern), which prevents invalid seq_ids from being specified.

  • the dunder methods __lt__, __eq__, __str__ and __hash__ are implemented inside the enum rather than in the Stage class, because their behavior is tied to how the enum members are used

  • hashing is based on the member's _value_, which is supposed to be immutable (hashing Stage would have been tricky because it is mutable)
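The reason `__hash__` has to live in the enum next to `__eq__` is a general Python rule: a class that defines `__eq__` without also defining `__hash__` gets `__hash__` set to `None`, which hides any hash it would otherwise inherit. A minimal sketch with hypothetical class names:

```python
class HashedBase:
    def __hash__(self):
        return 42

class EqOnly(HashedBase):
    # Defining __eq__ here (and not __hash__) makes Python set
    # EqOnly.__hash__ = None, shadowing HashedBase.__hash__.
    def __eq__(self, other):
        return self is other

class EqAndHash(HashedBase):
    def __eq__(self, other):
        return self is other
    # Restore hashability explicitly (a full def __hash__ works too).
    __hash__ = HashedBase.__hash__

try:
    hash(EqOnly())
    eq_only_hashable = True
except TypeError:
    eq_only_hashable = False

assert EqOnly.__hash__ is None
assert hash(EqAndHash()) == 42
```

The same applies to an Enum that mixes in Stage: overriding `__eq__` there drops whatever `__hash__` Stage or Enum provided unless it is redefined in the same class.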

from dataclasses import dataclass
from enum import Enum, auto
from functools import total_ordering
from typing import override  # Python 3.12+

@dataclass
class Stage:
    label:        str
    n_instances:  int = 0

@total_ordering
class Stages(Stage, Enum): #req no. 1 (members are Stage instances)

    #req no. 1.1
    def __lt__(self, other):
        if self.__class__ is other.__class__:
            return self.value < other.value
        return NotImplemented
    def __eq__(self, other):
        if self.__class__ is other.__class__:
            return self.value == other.value
        return NotImplemented

    #req no. 1.2 (defined here because defining __eq__ alone sets __hash__ to None)
    def __hash__(self):
        return hash(self.value)

    #req no. 1.3
    def __str__(self):
        return self.label

    #req no. 2
    def __new__(cls, label):
        #auto numbering enum pattern: seq_ids are 0, 1, 2, ... in declaration order
        value = len(cls.__members__)

        #Stage has no __new__ of its own, so this is object.__new__;
        #Enum then calls Stage.__init__(label) on the new member
        obj = Stage.__new__(cls)

        #enum value set to seq_id
        obj._value_ = value

        return obj

    #req no. 3
    @override
    def _generate_next_value_(name, start, count, last_values):
        return name.lower()

    ADMISSION = auto()
    PROCESSING = auto()
    QC = auto()
    DELIVERY = auto()

#ordering test
assert Stages.PROCESSING < Stages.DELIVERY
assert Stages.DELIVERY > Stages.PROCESSING

#dictionary key test
stage_to_color = {
    Stages.ADMISSION :  "#B10156",
    Stages.PROCESSING : "#F4B704",
    Stages.QC :         "#FD0002",
    Stages.DELIVERY :   "#7FB857"
}

assert stage_to_color[Stages.QC] == "#FD0002"

#str and seq_id tests
assert str(Stages.QC) == "qc"
assert Stages.ADMISSION.value == 0
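Stages writes only `__lt__` and `__eq__` by hand; the remaining comparisons, such as `Stages.DELIVERY > Stages.PROCESSING` from requirement 1.1, are filled in by `functools.total_ordering`, which derives `__gt__`, `__le__` and `__ge__` from `__lt__` and equality. A minimal sketch of that mechanism on a hypothetical `Level` enum, without the Stage mixin:

```python
from enum import Enum
from functools import total_ordering

@total_ordering
class Level(Enum):              # hypothetical enum, for illustration only
    LOW = 0
    MEDIUM = 1
    HIGH = 2

    # Only __lt__ is written by hand; Enum's identity-based __eq__ is kept.
    def __lt__(self, other):
        if self.__class__ is other.__class__:
            return self.value < other.value
        return NotImplemented

# total_ordering derived these from __lt__
assert Level.HIGH > Level.MEDIUM
assert Level.LOW <= Level.LOW
assert Level.MEDIUM >= Level.LOW
assert sorted([Level.HIGH, Level.LOW, Level.MEDIUM]) == [Level.LOW, Level.MEDIUM, Level.HIGH]
assert max(Level) is Level.HIGH
```

Because Level does not override `__eq__`, it also keeps Enum's `__hash__` and stays usable as a dictionary key.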

Solution

  • Both dataclass and Enum do a lot of work to make things simple for the user, so you need to be careful when you start extending and/or combining them.

    Working code:

    from dataclasses import dataclass
    from enum import Enum, auto
    from functools import total_ordering
    from typing import override                   # Python 3.12+
    
    @dataclass
    class Stage:
    
        seq_id:       int
        label:        str                         # name and value are reserved by Enum
        n_instances:  int = 0
    
        # req 1.2
        def __hash__(self):
            return hash(self.seq_id)
    
    
    @total_ordering
    class Stages(Stage, Enum):                    #req no. 1 (members are Stage instances)
    
        #req no. 1.1 (ordering)
        def __lt__(self, other):
            if self.__class__ is other.__class__:
                return self.seq_id < other.seq_id
            return NotImplemented
    
        def __eq__(self, other):
            if self.__class__ is other.__class__:
                return self.seq_id == other.seq_id
            return NotImplemented
    
        #req no. 1.2: defining __eq__ without __hash__ would set __hash__ to None,
        #so re-expose the seq_id-based hash from Stage
        __hash__ = Stage.__hash__
    
        #req no. 1.3
        def __str__(self):
            return self.label
    
        #req nos. 2 & 3: the (count, label) tuple is unpacked into Stage's fields
        @override
        def _generate_next_value_(name, start, count, last_values):
            return count, name.lower()
    
        ADMISSION = auto()
        PROCESSING = auto()
        QC = auto()
        DELIVERY = auto()
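In this version `_generate_next_value_` returns a `(count, lower-cased name)` tuple. When the enum has a mixin type, Enum unpacks a tuple value into the arguments used to initialise each member, so every member is set up as if by `Stage(seq_id, label)`. A reduced sketch with a hypothetical `Tag` dataclass standing in for Stage:

```python
from dataclasses import dataclass
from enum import Enum, auto

@dataclass
class Tag:                     # hypothetical stand-in for Stage
    seq_id: int
    label:  str

class Tags(Tag, Enum):
    # auto() calls this; the tuple it returns is unpacked as (seq_id, label)
    def _generate_next_value_(name, start, count, last_values):
        return count, name.lower()

    FIRST = auto()
    SECOND = auto()

# Each member was initialised with the unpacked tuple
assert isinstance(Tags.SECOND, Tag)
assert (Tags.SECOND.seq_id, Tags.SECOND.label) == (1, "second")
assert [t.seq_id for t in Tags] == [0, 1]
```

Here `count` is the number of members declared so far, which is what gives the 0-based sequence.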