
Python dictionary with enum as key


Let's say I have an enum

from enum import Enum

class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"

I want to create a ColorDict class that behaves like a native Python dictionary but only accepts a Color member or its corresponding string value as a key.

d = ColorDict() # I want to implement a ColorDict class such that ...

d[Color.RED] = 123
d["RED"] = 456  # I want this to override the previous value
d[Color.RED]    # ==> 456
d["foo"] = 789  # I want this to raise a KeyError exception

What's the "pythonic way" of implementing this ColorDict class? Should I use inheritance (subclassing Python's native dict) or composition (keeping a dict as a member)?


Solution

  • A simple solution would be to slightly modify your Color enum and then subclass dict to validate the key. I would do something like this:

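    from enum import Enum


    class Color(Enum):
        RED = "RED"
        GREEN = "GREEN"
        BLUE = "BLUE"

        @classmethod
        def is_color(cls, color):
            """Return True if color is a Color member or one of the members' values."""
            if isinstance(color, cls):
                return True
            # Compare strings case-insensitively against the allowed values
            return isinstance(color, str) and any(color.upper() == c.value for c in cls)


    class ColorDict(dict):

        @staticmethod
        def _to_color(k):
            # Convert a valid key to its Color member; reject anything else
            if not Color.is_color(k):
                raise KeyError(f"Color {k!r} is not valid")
            return k if isinstance(k, Color) else Color(k.upper())

        def __setitem__(self, k, v):
            # Store every key as a Color member, so "RED" and Color.RED collide
            super().__setitem__(self._to_color(k), v)

        def __getitem__(self, k):
            # Normalize the key the same way, so d["RED"] finds Color.RED
            return super().__getitem__(self._to_color(k))


    d = ColorDict()

    d[Color.RED] = 123
    d["RED"] = 456       # overrides the previous value
    print(d[Color.RED])  # 456
    d["foo"] = 789       # raises KeyError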
    

    In the Color class, I have added a test method, is_color, that returns True if a key is a Color member or one of the allowed string values, and False otherwise. Calling upper() on a string key converts it to upper case so it can be compared with the predefined values.
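Enum members can be looked up by name (`Color[...]`, or the `Color.__members__` mapping) or by value (`Color(...)`), so a key check and the conversion that follows should use the same kind of lookup. For `Color` the two coincide because each name equals its value. The minimal sketch below uses a hypothetical `Shade` enum, whose names and values differ, to show where they diverge:

```python
from enum import Enum


class Shade(Enum):
    # Hypothetical enum whose member names differ from their values
    LIGHT = "light"
    DARK = "dark"


# __members__ maps member *names* to members
print("LIGHT" in Shade.__members__)   # True
print("light" in Shade.__members__)   # False

# Subscription looks a member up by name, calling looks it up by value
print(Shade["LIGHT"] is Shade.LIGHT)  # True
print(Shade("light") is Shade.LIGHT)  # True

# A name is not a valid value, so the call form rejects it
try:
    Shade("LIGHT")
except ValueError as exc:
    print("ValueError:", exc)
```

So a check against `__members__` pairs naturally with `Color[...]`, and a check against values pairs with `Color(...)`.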

    Then I have subclassed dict, overriding the __setitem__ special method to validate the key before storing it, and __getitem__ to convert any key passed as a str into the corresponding Enum member. Depending on how you want to use the ColorDict class, you may need to override more methods: for example, the dict constructor, update() and setdefault() do not call an overridden __setitem__, and get() and the in operator do not go through __getitem__. There's a good explanation of that here: How to properly subclass dict and override __getitem__ & __setitem__
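The question also asks about composition. As a sketch of that alternative (not the approach used in the answer above), a class can wrap a plain dict and inherit from `collections.abc.MutableMapping`. The mixin methods it supplies, such as `update()`, `setdefault()`, `get()`, `pop()` and `in`, are built on the five abstract methods, so every access path goes through the same key check. The helper name `_to_color` is hypothetical.

```python
from collections.abc import MutableMapping
from enum import Enum


class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"


class ColorDict(MutableMapping):
    """Dict-like container that only accepts Color members or their string values."""

    def __init__(self, *args, **kwargs):
        self._data = {}
        # MutableMapping.update routes every item through __setitem__
        self.update(*args, **kwargs)

    @staticmethod
    def _to_color(key):
        # Accept a Color member as-is; convert a string by value (case-insensitive)
        if isinstance(key, Color):
            return key
        if isinstance(key, str):
            try:
                return Color(key.upper())
            except ValueError:
                pass
        raise KeyError(f"Color {key!r} is not valid")

    def __getitem__(self, key):
        return self._data[self._to_color(key)]

    def __setitem__(self, key, value):
        self._data[self._to_color(key)] = value

    def __delitem__(self, key):
        del self._data[self._to_color(key)]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def __repr__(self):
        return f"{type(self).__name__}({self._data!r})"
```

The trade-off is that this ColorDict is no longer a dict instance, so `isinstance(d, dict)` is False, and each operation runs through Python-level methods rather than dict's C implementation.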