python, fb-hydra

Restrict possible values in hydra structured configs


I am trying to adapt my app to the Hydra framework. I use a structured config schema, and I want to restrict the possible values of some fields. Is there any way to do that?

Here is my code:

my_app.py:

from dataclasses import dataclass

import hydra


@dataclass
class Config:
    # possible values are 'foo' and 'bar'
    some_value: str = "foo"


@hydra.main(config_path="configs", config_name="config")
def main(cfg: Config):
    print(cfg)


if __name__ == "__main__":
    main()

configs/config.yaml:

# value is incorrect.
# I need hydra to throw an exception in this case
some_value: "barrr"

Solution

  • A few options:

    1) If your acceptable values are enumerable, use an Enum type:

    from enum import Enum
    from dataclasses import dataclass
    
    class SomeValue(Enum):
        foo = 1
        bar = 2
    
    @dataclass
    class Config:
        # possible values are 'foo' and 'bar'
        some_value: SomeValue = SomeValue.foo
    

    If no fancy logic is needed to validate some_value, this is the solution I would recommend.

    2) If you are using YAML files, you can use OmegaConf to register a custom resolver:

    # my_python_file.py
    import hydra
    from omegaconf import OmegaConf
    
    def check_some_value(value: str) -> str:
        assert value in ("foo", "bar")
        return value
    
    OmegaConf.register_new_resolver("check_foo_bar", check_some_value)
    
    @hydra.main(...)
    ...
    
    if __name__ == "__main__":
        main()
    
    # my_yaml_file.yaml
    some_value: ${check_foo_bar:foo}
    

    When you access cfg.some_value in your Python code, the resolver runs, and an error is raised if the value does not pass check_some_value. (In OmegaConf 2.1 and later, the AssertionError is wrapped in an InterpolationResolutionError.)

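    3) After config composition is completed, you can call OmegaConf.to_object (OmegaConf 2.1+) to create an instance of your dataclass. Because this constructs a real dataclass instance, the dataclass's __post_init__ method gets called, so validation placed there will run. This only works if the composed config is backed by the Config dataclass (for example, when the schema is registered with Hydra's ConfigStore); for a config loaded from plain YAML, to_object returns a dict instead.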

    import hydra
    from dataclasses import dataclass
    from omegaconf import DictConfig, OmegaConf
    
    @dataclass
    class Config:
        # possible values are 'foo' and 'bar'
        some_value: str = "foo"
    
        def __post_init__(self) -> None:
            assert self.some_value in ("foo", "bar")
    
    @hydra.main(config_path="configs", config_name="config")
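    def main(dict_cfg: DictConfig):
        # to_object returns a Config instance (running __post_init__) only if
        # dict_cfg is backed by the Config dataclass; otherwise it returns a dict.
        cfg: Config = OmegaConf.to_object(dict_cfg)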
        print(cfg)
    
    if __name__ == "__main__":
        main()