How to iterate through categorical combinatorics of different Enum classes in Python?


I have a function, main(), that takes an instance of a class InputPermutation as its only argument. The idea is that any difference in the results of a run of main() comes solely from differences in the InputPermutation instance. I need to iterate through all possible configurations of InputPermutation and record the result of main() for each run.

InputPermutation has attributes that are members of different Enum classes. To get all possible configurations, I need to iterate through every member of the Enum class used for each InputPermutation attribute.

Here is a simplified model of the problem I am facing.

from enum import Enum

class Colour(Enum):
    GREEN = "green"
    BLUE = "blue"
    RED = "red"

class Country(Enum):
    ENGLAND = "england"
    JAPAN = "japan"
    AUSTRALIA = "australia"

class InputPermutation:
    def __init__(self, colour: Colour, country: Country):
        self.colour = colour
        self.country = country

def main(input_permutation: InputPermutation) -> dict:
    
    colour = input_permutation.colour.value
    country = input_permutation.country.value

    result = {colour: country}

    return result

def iterate() -> dict:

    pass

I need help writing this iterate() function; I cannot figure out how to make it work.

I would like the function to return a single dictionary in which each key is a "run_index" (a number that increases by 1 for each run of main()) and each value is the dictionary returned by that run, giving a structure like this:

{
    1: {"green": "england"},
    2: {"green": "japan"},
    3: {"green": "australia"},
    4: {"blue": "england"},
    # etc...
}

I would like it to scale: no matter how many Enum classes (or new members of the existing Enum classes) are added to InputPermutation, the function should still iterate through every option. I have already managed to produce this output, but only with a version that does not scale in this way.

The difficulty may be specific to my use of Enum classes. I chose them because of the drop-down they give me when I am selecting options. They also standardise the inputs: you have to pick a member in a fixed format before it is converted to a string, which reduces the likelihood of typos.
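As a small illustration of the standardisation point (a sketch, not part of the original question): calling an Enum class with a value, e.g. `Colour("green")`, returns the matching member, and an unknown value raises `ValueError` instead of silently passing a typo through. The helper `parse_colour` is hypothetical.

```python
from enum import Enum


class Colour(Enum):
    GREEN = "green"
    BLUE = "blue"
    RED = "red"


def parse_colour(text: str) -> Colour:
    """Hypothetical helper: turn free text into a Colour member, rejecting typos."""
    # Colour(value) looks the member up by its value; unknown values raise ValueError
    return Colour(text.strip().lower())


print(parse_colour(" Green "))  # Colour.GREEN
try:
    parse_colour("gren")
except ValueError as err:
    print(err)  # explains that 'gren' is not a valid Colour
```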

The problem has a real-world application in a model I am building, but I thought this colour/country example would be easier to work with here.


Solution

  • I think that this is what you are looking for:

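    # new imports:
    import typing
    import itertools
    
    def iterate() -> dict:
        # map each __init__ parameter name to its annotated Enum class,
        # e.g. {"colour": Colour, "country": Country}
        input_args_types = typing.get_type_hints(InputPermutation.__init__)
        input_args_types.pop("return", None)  # ignore a "-> None" hint if present
        # iterating over an Enum class yields its members
        all_InputPermutation_combos = list(
            itertools.product(*input_args_types.values())
        )
        return {
            i: {a.value for a in arg}
            for i, arg in enumerate(all_InputPermutation_combos)
        }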
    

    Output:

    {0: {'england', 'green'},
     1: {'green', 'japan'},
     2: {'australia', 'green'},
     3: {'blue', 'england'},
     4: {'blue', 'japan'},
     5: {'australia', 'blue'},
     6: {'england', 'red'},
     7: {'japan', 'red'},
     8: {'australia', 'red'}}
    

    It still works if you add another Enum and a matching parameter:

    class Extra(Enum):
        TEST = "test1"
        TEST2 = "test2"
    
    
    class InputPermutation:
        def __init__(self, colour: Colour, country: Country, extra: Extra):
            self.colour = colour
            self.country = country
            self.extra = extra
    

    Then the output is:

    >>> iterate()
    
    {0: {'england', 'green', 'test1'},
     1: {'england', 'green', 'test2'},
     2: {'green', 'japan', 'test1'},
     3: {'green', 'japan', 'test2'},
     4: {'australia', 'green', 'test1'},
     5: {'australia', 'green', 'test2'},
     6: {'blue', 'england', 'test1'},
     7: {'blue', 'england', 'test2'},
     8: {'blue', 'japan', 'test1'},
     9: {'blue', 'japan', 'test2'},
     10: {'australia', 'blue', 'test1'},
     11: {'australia', 'blue', 'test2'},
     12: {'england', 'red', 'test1'},
     13: {'england', 'red', 'test2'},
     14: {'japan', 'red', 'test1'},
     15: {'japan', 'red', 'test2'},
     16: {'australia', 'red', 'test1'},
     17: {'australia', 'red', 'test2'}}
    

    The function uses the type hints of InputPermutation.__init__, so it only works if every parameter is annotated with its Enum class. The hints tell it which Enums are used as inputs, and itertools.product creates every combination exactly once.
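To see the two building blocks on their own, here is a minimal sketch using the question's Colour and Country enums: `typing.get_type_hints` reads the annotations of `__init__`, iterating an Enum class yields its members in definition order, and `itertools.product` forms the cartesian product. The `-> None` annotation is an assumption added only to show why the `"return"` key has to be dropped.

```python
import itertools
import typing
from enum import Enum


class Colour(Enum):
    GREEN = "green"
    BLUE = "blue"
    RED = "red"


class Country(Enum):
    ENGLAND = "england"
    JAPAN = "japan"
    AUSTRALIA = "australia"


class InputPermutation:
    # "-> None" is assumed here for illustration; the question's __init__ has no return hint
    def __init__(self, colour: Colour, country: Country) -> None:
        self.colour = colour
        self.country = country


hints = typing.get_type_hints(InputPermutation.__init__)
print(hints)  # also contains 'return' (NoneType) because of "-> None"
hints.pop("return", None)  # keep only the parameter annotations

print(list(Colour))  # members in definition order: GREEN, BLUE, RED

combos = list(itertools.product(*hints.values()))
print(len(combos))  # 3 colours x 3 countries = 9
print(combos[0])  # (<Colour.GREEN: 'green'>, <Country.ENGLAND: 'england'>)
```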

    Edit

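    If you want to use the result of iterate to create instances, you can adjust the function slightly so that each entry maps parameter names to Enum members:

    def iterate_v1() -> dict:
        # get all options, keyed by __init__ parameter name:
        input_args_types = typing.get_type_hints(InputPermutation.__init__)
        input_args_types.pop("return", None)
        all_InputPermutation_combos = list(
            itertools.product(*input_args_types.values())
        )
        # pair each parameter name with the Enum member chosen for it
        return {
            i: dict(zip(input_args_types.keys(), arg))
            for i, arg in enumerate(all_InputPermutation_combos)
        }
    

    The output is then like this:

    {0: {'colour': <Colour.GREEN: 'green'>, 'country': <Country.ENGLAND: 'england'>},
     1: {'colour': <Colour.GREEN: 'green'>, 'country': <Country.JAPAN: 'japan'>},
     2: {'colour': <Colour.GREEN: 'green'>, 'country': <Country.AUSTRALIA: 'australia'>},
     3: {'colour': <Colour.BLUE: 'blue'>, 'country': <Country.ENGLAND: 'england'>},
     4: {'colour': <Colour.BLUE: 'blue'>, 'country': <Country.JAPAN: 'japan'>},
     5: {'colour': <Colour.BLUE: 'blue'>, 'country': <Country.AUSTRALIA: 'australia'>},
     6: {'colour': <Colour.RED: 'red'>, 'country': <Country.ENGLAND: 'england'>},
     7: {'colour': <Colour.RED: 'red'>, 'country': <Country.JAPAN: 'japan'>},
     8: {'colour': <Colour.RED: 'red'>, 'country': <Country.AUSTRALIA: 'australia'>}}
    

    Each entry can then be unpacked as keyword arguments:

    InputPermutation(**iterate_v1()[0])
    # or:
    for run in iterate_v1().values():
        InputPermutation(**run)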
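Putting this together for the question's original goal, a minimal sketch that builds every InputPermutation, runs main() on it, and stores each result under a run index starting at 1. It assumes, as the answer does, that every `__init__` parameter is annotated with an Enum class and that the parameter names are the keywords InputPermutation accepts.

```python
import itertools
import typing
from enum import Enum


class Colour(Enum):
    GREEN = "green"
    BLUE = "blue"
    RED = "red"


class Country(Enum):
    ENGLAND = "england"
    JAPAN = "japan"
    AUSTRALIA = "australia"


class InputPermutation:
    def __init__(self, colour: Colour, country: Country):
        self.colour = colour
        self.country = country


def main(input_permutation: InputPermutation) -> dict:
    colour = input_permutation.colour.value
    country = input_permutation.country.value
    return {colour: country}


def iterate() -> dict:
    # parameter name -> Enum class, read from the annotations on __init__
    hints = typing.get_type_hints(InputPermutation.__init__)
    hints.pop("return", None)  # ignore a return annotation if one is ever added
    names = list(hints)
    results = {}
    # product() yields one tuple of Enum members per configuration
    for run_index, members in enumerate(itertools.product(*hints.values()), start=1):
        kwargs = dict(zip(names, members))  # e.g. {"colour": Colour.GREEN, ...}
        results[run_index] = main(InputPermutation(**kwargs))
    return results


print(iterate())  # {1: {'green': 'england'}, 2: {'green': 'japan'}, ...}
```

Because iterate() never names the individual Enums, adding another Enum-typed parameter to `InputPermutation.__init__` (and using it in main()) extends the loop automatically.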