
How can I override the default behavior of `list(MyEnum)`?


I have a custom enum, MyEnum, with some members that have different names but the same value.

from enum import Enum

class MyEnum(Enum):
    A = 1
    B = 2
    C = 3
    D = 1  # Same value as A

Consequently, list(MyEnum) returns only the canonical members, one per value. D is treated as an alias of A and is skipped:

>>> list(MyEnum)
[<MyEnum.A: 1>, <MyEnum.B: 2>, <MyEnum.C: 3>]

By contrast, list(MyEnum.__members__) returns all the names, aliases included:

>>> list(MyEnum.__members__)
['A', 'B', 'C', 'D']
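The `__members__` mapping can also tell you which names are aliases. A minimal sketch, using the comparison shown in the `enum` documentation: an alias maps to the canonical member, whose `.name` is the first name defined for that value, so `MyEnum.D.name` is `'A'`.

```python
from enum import Enum


class MyEnum(Enum):
    A = 1
    B = 2
    C = 3
    D = 1  # Same value as A, so D becomes an alias of A


# __members__ maps every name, aliases included, to its member object.
all_names = list(MyEnum.__members__)

# An alias's member reports the canonical name, which differs from the alias key.
aliases = [name for name, member in MyEnum.__members__.items() if member.name != name]

print(all_names)             # ['A', 'B', 'C', 'D']
print(aliases)               # ['D']
print(MyEnum.D is MyEnum.A)  # True
```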

However, when I try to override the __iter__() method in my enum, the override has no effect:

class MyEnum(Enum):
    A = 1
    B = 2
    C = 3
    D = 1  # Same value as A

    @classmethod # an attempt to override list(MyEnum) that doesn't change anything  
    def __iter__(cls):
        return iter(list(cls.__members__)) 

list(MyEnum) never calls the custom __iter__(): a print() placed before its return statement never runs.
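A minimal sketch of that check, using a hypothetical `calls` list instead of print() so the result can be inspected afterwards. `list(MyEnum)` never records a call, while calling `MyEnum.__iter__()` explicitly does, because an ordinary attribute lookup on the class finds the classmethod in the class body.

```python
from enum import Enum

calls = []  # hypothetical: records every time the custom __iter__ actually runs


class MyEnum(Enum):
    A = 1
    B = 2
    C = 3
    D = 1  # Same value as A

    @classmethod
    def __iter__(cls):
        calls.append("custom __iter__")
        return iter(list(cls.__members__))


print(list(MyEnum))             # [<MyEnum.A: 1>, <MyEnum.B: 2>, <MyEnum.C: 3>]
print(calls)                    # [] : list() did not use the classmethod
print(list(MyEnum.__iter__()))  # ['A', 'B', 'C', 'D'] : an explicit call does run it
print(calls)                    # ['custom __iter__']
```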

Why is that?

How can I override the default behavior of list(MyEnum) so that I get all the distinct names?


Solution

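  • A class method is not the same as an instance method on a metaclass, and the latter is what __iter__ for Enum is. When Python evaluates list(MyEnum), iter(MyEnum), or a for loop over MyEnum, it looks up __iter__ implicitly on type(MyEnum), which is the metaclass EnumType (named EnumMeta before Python 3.11), not in MyEnum's own class dictionary. The classmethod you defined in the class body is therefore never consulted. You need to define a new metaclass, which you can use to define a new subclass of Enum that does what you are looking for.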

    A caveat: I make no claim that replacing EnumType.__iter__ with your definition is compatible with the rest of Enum's semantics, only that it makes your definition take effect.

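    try:
        from enum import EnumType  # Python 3.11+
    except ImportError:
        from enum import EnumMeta as EnumType  # same metaclass under its pre-3.11 name
    from enum import Enum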
    
    class MyEnumType(EnumType):
        def __iter__(self):
            # Must return an Iterator, something with a __next__ method,
            # not an Iterable.
            return iter(list(self.__members__))
    
    class MyEnumBase(Enum, metaclass=MyEnumType):
        pass
    
    class MyEnum(MyEnumBase):
        A = 1
        B = 2
        C = 3
        D = 1
    
    assert list(MyEnum) == ['A', 'B', 'C', 'D']
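To make the caveat concrete, here is a sketch that attaches the metaclass directly to the enum instead of going through an intermediate base (assuming Python 3.11+ for `EnumType`, with a fallback to `EnumMeta` on older versions). Only `__iter__` is replaced. `len()` and `reversed()` go through their own metaclass methods, so they still see only the three canonical members, and `reversed()` still yields members rather than names.

```python
try:
    from enum import EnumType  # Python 3.11+
except ImportError:  # Python 3.10 and earlier only have the older name
    from enum import EnumMeta as EnumType
from enum import Enum


class MyEnumType(EnumType):
    def __iter__(cls):
        # Iterate over every name in definition order, aliases included.
        return iter(cls.__members__)


class MyEnum(Enum, metaclass=MyEnumType):
    A = 1
    B = 2
    C = 3
    D = 1  # alias of A


names = list(MyEnum)                # uses MyEnumType.__iter__
size = len(MyEnum)                  # still EnumType.__len__, which skips aliases
backwards = list(reversed(MyEnum))  # still EnumType.__reversed__, which yields members

print(names)      # ['A', 'B', 'C', 'D']
print(size)       # 3
print(backwards)  # [<MyEnum.C: 3>, <MyEnum.B: 2>, <MyEnum.A: 1>]
```

Any code that expects `for member in MyEnum` to produce members will now receive strings instead. If you only need the names occasionally, `list(MyEnum.__members__)` from the question gives the same list without changing iteration for every caller.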