Search code examples
Tags: fb-hydra, omegaconf

fb-hydra: How to implement 2 nested structured configs?


I have two sub-configs and one master (primary?) config that contains both of them. I designed the configs as follows:

from dataclasses import dataclass, field

import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING, DictConfig

from typing import Any, List

@dataclass
class DBConfig:
    """Connection settings shared by every database backend.

    ``driver`` and ``port`` default to ``MISSING`` (printed as ``???``) and
    are expected to be provided by a concrete backend subclass.
    """

    host: str = field(default="localhost")
    driver: str = field(default=MISSING)
    port: int = field(default=MISSING)


@dataclass
class MySQLConfig(DBConfig):
    """DBConfig preset with the ``mysql`` driver on port 3306."""

    driver: str = field(default="mysql")
    port: int = field(default=3306)


@dataclass
class PostGreSQLConfig(DBConfig):
    """DBConfig preset with the ``postgresql`` driver, plus a ``timeout`` knob."""

    driver: str = field(default="postgresql")
    port: int = field(default=5432)
    timeout: int = field(default=10)


def _connection_defaults() -> List[Any]:
    """Build a fresh defaults list that selects ``mysql`` for ``params``."""
    return [{"params": "mysql"}]  # I'd like to set mysql as a default


@dataclass
class ConnectionConfig:
    """Target class path plus the database ``params`` node.

    NOTE: as of Hydra 1.0 a ``defaults`` list is only honored in the primary
    config, so this nested ``defaults`` ends up as an ordinary field and
    ``params`` stays ``???``.
    """

    target: str = field(default="app.my_class.MyClass")
    params: DBConfig = field(default=MISSING)
    defaults: List[Any] = field(default_factory=_connection_defaults)



@dataclass
class AConfig:
    """Base params schema carrying a ``name`` field."""

    name: str = field(default="foo")


@dataclass
class BConfig(AConfig):
    """AConfig variant whose ``age`` defaults to 10."""

    age: int = field(default=10)


@dataclass
class CConfig(AConfig):
    """AConfig variant whose ``age`` defaults to 20."""

    age: int = field(default=20)


def _some_other_defaults() -> List[Any]:
    """Build a fresh defaults list that selects ``bconfig`` for ``params``."""
    return [{"params": "bconfig"}]  # I'd like to set bconfig as a default


@dataclass
class SomeOtherConfig:
    """Target class path plus an ``AConfig``-typed ``params`` node.

    NOTE: as of Hydra 1.0 a ``defaults`` list is only honored in the primary
    config, so this nested ``defaults`` ends up as an ordinary field and
    ``params`` stays ``???``.
    """

    target: str = field(default="app.my_class.MyClass2")
    params: AConfig = field(default=MISSING)
    defaults: List[Any] = field(default_factory=_some_other_defaults)



@dataclass
class Config:
    """Primary config holding both nested connection configs.

    Nested dataclass defaults use ``default_factory``. Python 3.11+ rejects a
    (non-frozen, hence unhashable) dataclass instance as a field default
    with ``ValueError``. A factory also gives each Config its own
    sub-config instead of one shared object.
    """

    db_connection: ConnectionConfig = field(default_factory=ConnectionConfig)
    some_other: SomeOtherConfig = field(default_factory=SomeOtherConfig)


@hydra.main(config_name="config")
def my_app(cfg: DictConfig) -> None:
    """Print the composed config as YAML.

    ``cfg.pretty()`` was deprecated in favor of ``OmegaConf.to_yaml`` and is
    gone in newer OmegaConf releases, so call the replacement directly.
    """
    # Local import: the file only imports MISSING and DictConfig from omegaconf.
    from omegaconf import OmegaConf

    print(OmegaConf.to_yaml(cfg))
    # connection = hydra.utils.instantiate(cfg)
    # print(connection)


if __name__ == "__main__":
    cs = ConfigStore.instance()
    cs.store(name="config", node=Config)

    # Every option below is registered into the same "params" group, in this order.
    for option_name, option_node in (
        ("mysql", MySQLConfig),
        ("postgresql", PostGreSQLConfig),
        ("bconfig", BConfig),
        ("cconfig", CConfig),
    ):
        cs.store(group="params", name=option_name, node=option_node)

    my_app()

What I expected when I run the program without any options:

db_connection:
    target: app.my_class.MyClass
    params:   
        host: localhost
        driver: mysql
        port: 3306   

some_other:
    target: app.my_class.MyClass2
    params:
        name: "foo"
        age: 10

But the actual result is:

db_connection:
    target: app.my_class.MyClass
    params: ???
    defaults:
    - params: mysql
some_other:
    target: app.my_class.MyClass2
    params: ???
    defaults:
    - params: bconfig

Solution

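  • First of all: as of Hydra 1.0, the defaults list is ONLY supported in the primary config. That is why, in your output, each nested `defaults` list is printed as an ordinary field and `params` stays `???` (missing). Below are two versions: the first changes as little as possible in your example, and the second cleans things up a bit.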

    Example 1:

    from dataclasses import dataclass, field
    
    import hydra
    from hydra.core.config_store import ConfigStore
    from omegaconf import MISSING, DictConfig
    
    from typing import Any, List
    
    
    @dataclass
    class DBConfig:
        """Connection settings shared by every database backend.

        ``driver`` and ``port`` default to ``MISSING`` (``???``) and are
        expected to be provided by a concrete backend subclass.
        """

        host: str = field(default="localhost")
        driver: str = field(default=MISSING)
        port: int = field(default=MISSING)


    @dataclass
    class MySQLConfig(DBConfig):
        """DBConfig preset with the ``mysql`` driver on port 3306."""

        driver: str = field(default="mysql")
        port: int = field(default=3306)


    @dataclass
    class PostGreSQLConfig(DBConfig):
        """DBConfig preset with the ``postgresql`` driver, plus a ``timeout`` knob."""

        driver: str = field(default="postgresql")
        port: int = field(default=5432)
        timeout: int = field(default=10)
    
    
    @dataclass
    class ConnectionConfig:
        """Target class path plus the database ``params`` node.

        ``params`` starts out as ``MISSING``; an option from the
        ``db_connection/params`` group is meant to fill it in.
        """

        target: str = field(default="app.my_class.MyClass")
        params: DBConfig = field(default=MISSING)
    
    
    @dataclass
    class AConfig:
        """Base params schema carrying a ``name`` field."""

        name: str = field(default="foo")


    @dataclass
    class BConfig(AConfig):
        """AConfig variant whose ``age`` defaults to 10."""

        age: int = field(default=10)


    @dataclass
    class CConfig(AConfig):
        """AConfig variant whose ``age`` defaults to 20."""

        age: int = field(default=20)
    
    
    @dataclass
    class SomeOtherConfig:
        """Target class path plus an ``AConfig``-typed ``params`` node.

        ``params`` starts out as ``MISSING``; an option from the
        ``some_other/params`` group is meant to fill it in.
        """

        target: str = field(default="app.my_class.MyClass2")
        params: AConfig = field(default=MISSING)
    
    
    @dataclass
    class Config:
        """Primary config: both nested configs plus the Hydra defaults list.

        As of Hydra 1.0 the defaults list is only honored in the primary
        config, so the choice of each nested ``params`` option lives here.

        Nested dataclass defaults use ``default_factory``. Python 3.11+
        rejects a (non-frozen, hence unhashable) dataclass instance as a
        field default with ``ValueError``. A factory also avoids sharing one
        sub-config object between Config instances.
        """

        db_connection: ConnectionConfig = field(default_factory=ConnectionConfig)
        some_other: SomeOtherConfig = field(default_factory=SomeOtherConfig)
        defaults: List[Any] = field(
            default_factory=lambda: [
                {"db_connection/params": "mysql"},
                {"some_other/params": "bconfig"},
            ]
        )
    
    
    @hydra.main(config_name="config")
    def my_app(cfg: DictConfig) -> None:
        """Print the composed config as YAML.

        ``cfg.pretty()`` was deprecated in favor of ``OmegaConf.to_yaml`` and
        is gone in newer OmegaConf releases, so call the replacement directly.
        """
        # Local import: the file only imports MISSING and DictConfig from omegaconf.
        from omegaconf import OmegaConf

        print(OmegaConf.to_yaml(cfg))
    
    
    if __name__ == "__main__":
        cs = ConfigStore.instance()
        cs.store(name="config", node=Config)

        # Each option's group is its path inside Config, which is what the
        # primary config's defaults list refers to.
        for group, option_name, option_node in (
            ("db_connection/params", "mysql", MySQLConfig),
            ("db_connection/params", "postgresql", PostGreSQLConfig),
            ("some_other/params", "bconfig", BConfig),
            ("some_other/params", "cconfig", CConfig),
        ):
            cs.store(group=group, name=option_name, node=option_node)

        my_app()
    
    

    Example 2:

    from dataclasses import dataclass, field
    
    import hydra
    from hydra.core.config_store import ConfigStore
    from omegaconf import MISSING, DictConfig
    from hydra.types import ObjectConf
    from typing import Any, List
    
    
    @dataclass
    class DBConfig:
        """Connection settings shared by every database backend.

        ``driver`` and ``port`` default to ``MISSING`` (``???``) and are
        expected to be provided by a concrete backend subclass.
        """

        host: str = field(default="localhost")
        driver: str = field(default=MISSING)
        port: int = field(default=MISSING)


    @dataclass
    class MySQLConfig(DBConfig):
        """DBConfig preset with the ``mysql`` driver on port 3306."""

        driver: str = field(default="mysql")
        port: int = field(default=3306)


    @dataclass
    class PostGreSQLConfig(DBConfig):
        """DBConfig preset with the ``postgresql`` driver, plus a ``timeout`` knob."""

        driver: str = field(default="postgresql")
        port: int = field(default=5432)
        timeout: int = field(default=10)
    
    
    @dataclass
    class AConfig:
        """Base params schema carrying a ``name`` field."""

        name: str = field(default="foo")


    @dataclass
    class BConfig(AConfig):
        """AConfig variant whose ``age`` defaults to 10."""

        age: int = field(default=10)


    @dataclass
    class CConfig(AConfig):
        """AConfig variant whose ``age`` defaults to 20."""

        age: int = field(default=20)
    
    
    # Default option per config group. As of Hydra 1.0 a defaults list is only
    # honored in the primary config, so it is attached to Config below.
    defaults = [{"db_connection": "mysql"}, {"some_other": "bconfig"}]


    @dataclass
    class Config:
        """Primary config; each group's node is an ObjectConf picked via ``defaults``."""

        db_connection: ObjectConf = MISSING
        some_other: ObjectConf = MISSING
        # Copy per instance: returning the module-level list itself would make
        # every Config alias (and be able to mutate) the same ``defaults``.
        defaults: List[Any] = field(
            default_factory=lambda: [dict(entry) for entry in defaults]
        )
    
    
    # Register the primary config plus one ObjectConf option per group entry.
    # ``params`` is passed as a dataclass instance consistently for every option.
    cs = ConfigStore.instance()
    cs.store(name="config", node=Config)
    cs.store(
        group="db_connection",
        name="mysql",
        node=ObjectConf(target="MySQL", params=MySQLConfig()),
    )
    cs.store(
        group="db_connection",
        name="postgresql",
        node=ObjectConf(target="PostgreSQL", params=PostGreSQLConfig()),
    )
    cs.store(
        group="some_other",
        name="bconfig",
        node=ObjectConf(target="ClassB", params=BConfig()),
    )
    cs.store(
        group="some_other",
        name="cconfig",
        # CConfig, not the AConfig base: selecting cconfig must yield age: 20.
        node=ObjectConf(target="ClassC", params=CConfig()),
    )
    
    
    @hydra.main(config_name="config")
    def my_app(cfg: DictConfig) -> None:
        """Print the composed config as YAML.

        ``cfg.pretty()`` was deprecated in favor of ``OmegaConf.to_yaml`` and
        is gone in newer OmegaConf releases, so call the replacement directly.
        """
        # Local import: the file only imports MISSING and DictConfig from omegaconf.
        from omegaconf import OmegaConf

        print(OmegaConf.to_yaml(cfg))


    if __name__ == "__main__":
        my_app()