python, json, fb-hydra

Load external .json file and merge it with structured hydra config


I am writing some scripts in Python and want to use Hydra (https://hydra.cc/). I am following the structured config pattern: I have a config.py and a config.yaml in a conf directory, and I validate my config using dataclasses, for example:

#### config.yaml

visualisation_conf:
  degradation:
    data_path: /measurements
    input_data_formats: .csv

#### config.py

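from dataclasses import dataclass

# DegradationConfig is defined first because VisualisationConfig refers to it.
@dataclass
class DegradationConfig:
    data_path: str
    input_data_formats: str

@dataclass
class VisualisationConfig:
    degradation: DegradationConfig

@dataclass
class MainHydraConfig:
    visualisation_conf: VisualisationConfig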


#### Usage in script:
import hydra
from conf.config import MainHydraConfig

@hydra.main(version_base=None, config_path="../conf", config_name="config")
def main(base_cfg: MainHydraConfig):
    input_type = base_cfg.visualisation_conf.degradation.input_data_formats
    ...

This worked fine until I needed to load a json file and add it to my configuration. I have an external json file (called "schema.json", in the conf directory) that I would like to load at runtime and merge with my main config. I have tried something like this:

#### config.yaml

dataset_schema:
  schema_path: ${hydra:runtime.cwd}/conf/schema.json
  schema: ...

#### config.py

@dataclass
class JSONSchemaConfig:
    schema_path: str
    schema: dict = field(default_factory=dict)

    def load_schema(self):
        if os.path.exists(self.schema_path):
            with open(self.schema_path, "r") as f:
                self.schema = json.load(f)


@dataclass
class MainHydraConfig:
    json_schema: JSONSchemaConfig = field(default_factory=JSONSchemaConfig)

    def __post_init__(self):
        self.json_schema.load_schema()
        merged_conf = OmegaConf.create(self.json_schema.schema)
        OmegaConf.merge(self, merged_conf)

However, nothing is loaded at runtime, because __post_init__ is never called. Is what I want to achieve even possible in OmegaConf/Hydra? Maybe I should try a different approach to my end goal, which is loading a json file as a dict and merging it with the rest of the configuration? I know I can move the json-loading logic into the script itself, but I am pretty sure I am just missing some small detail that prevents this from working as expected.

22/7/24 EDIT
Following @Daraan's answer, I tried Hydra's experimental callbacks. I adapted Daraan's answer and added parts from the documentation (https://hydra.cc/docs/experimental/callbacks/); it now looks like this:

#### conf/hydra/callbacks/load_json_schema_callback.yaml

# @package _global_
hydra:
  callbacks:
    load_json_schema_callback:
      _target_: conf.config.ParseJsonCallback


#### conf/config.yaml
defaults:
  - /hydra/callbacks:
      - load_json_schema_callback


dataset_schema:
  schema_path: [conf,schema.json]
  schema: ???


#### conf/config.py

@dataclass
class JSONSchemaConfig:
    schema_path: str
    schema: dict = field(default_factory=dict)

    def load_schema(self):
        cwd = hydra.core.hydra_config.HydraConfig.get().runtime.cwd
        _schema_path: str = os.path.join(cwd, *self.schema_path)

        if os.path.exists(_schema_path):
            with open(_schema_path, "r") as f:
                self.schema = json.load(f)

class ParseJsonCallback(Callback):
    def on_job_start(
        self, config: DictConfig, *, task_function: TaskFunction, **kwargs: Any
    ) -> None:
        JSONSchemaConfig.load_schema(config.dataset_schema)

When I debug this code, the json is loaded properly while I am inside 'load_schema', or even inside the 'on_job_start' frame. Once I leave those frames and return to the main script, the value in the main config is '???' again.


Solution

  • The dataclasses only back the schema and provide duck-typing. Although base_cfg is annotated as MainHydraConfig, at runtime it is a DictConfig and has none of the functionality of the schema classes.

    If you want __post_init__ to be called, you need to create one of the real MainHydraConfig instances.
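
The point above can be illustrated with the standard library alone. This is only a sketch under assumptions: the plain dict `config_like` is a hypothetical stand-in for the DictConfig that the task function receives, and the `post_init_ran` field exists purely to make the effect visible. Neither is part of the original config.

```python
from dataclasses import dataclass, field


@dataclass
class JSONSchemaConfig:
    schema_path: str = "conf/schema.json"
    schema: dict = field(default_factory=dict)
    post_init_ran: bool = False  # illustrative flag, not in the original config

    def __post_init__(self) -> None:
        # Runs only when JSONSchemaConfig(...) is actually constructed.
        self.post_init_ran = True


# Hypothetical stand-in for the object passed to main(): it has the same keys
# as the dataclass but is not an instance of it, so no dataclass hook has run.
config_like = {"schema_path": "conf/schema.json", "schema": {}, "post_init_ran": False}

# Constructing a real instance from those values is what triggers __post_init__.
real_instance = JSONSchemaConfig(**config_like)
```

Here `config_like["post_init_ran"]` stays False while `real_instance.post_init_ran` becomes True, which mirrors why a hook defined on the schema class has no effect on the config object itself.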

    As a solution, you could write an extra function that takes care of it. I think that with the function below, __post_init__ is no longer necessary.

    @dataclass
    class MainHydraConfig:
        ...

        @staticmethod
        def postprocess(settings: omegaconf.DictConfig) -> None:
            """
            Passes settings.json_schema (a DictConfig) as `self` to
            JSONSchemaConfig.load_schema, which replaces the
            settings.json_schema.schema node with the loaded json.
            """
            JSONSchemaConfig.load_schema(settings.json_schema)

    # ----

    @hydra.main(version_base=None, config_path="../conf", config_name="config")
    def main(base_cfg: MainHydraConfig):
        MainHydraConfig.postprocess(base_cfg)

        # You could also call the one-liner directly, but imo a descriptive name reads better:
        # JSONSchemaConfig.load_schema(base_cfg.json_schema)
    
    

    If you really want to use __post_init__, one of many possible ways is
    OmegaConf.create(OmegaConf.to_object(base_cfg)) # -> real MainHydraConfig -> DictConfig

    You could pack everything into a decorator to keep your main function cleaner: https://hydra.cc/docs/advanced/decorating_main/

    @hydra.main(...)
    @postprocess # write a decorator that takes care of parsing and inserting the json file
    def main(base_cfg):
        ...
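
One way such a decorator could look, as a hedged sketch rather than a definitive implementation. It assumes the config exposes json_schema.schema_path and json_schema.schema through attribute access, as in the question. The `SimpleNamespace` object and the temporary schema file are hypothetical stand-ins that make the example runnable without Hydra.

```python
import functools
import json
import os
import tempfile
from types import SimpleNamespace
from typing import Any, Callable


def postprocess(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
    """Load the json file named by cfg.json_schema.schema_path, then call func."""

    @functools.wraps(func)  # keep the wrapped function's name and module
    def wrapper(cfg: Any) -> Any:
        node = cfg.json_schema
        if os.path.exists(node.schema_path):
            with open(node.schema_path, "r") as f:
                node.schema = json.load(f)
        return func(cfg)

    return wrapper


@postprocess
def main(cfg: Any) -> Any:
    # By the time the body runs, the schema has been filled in.
    return cfg.json_schema.schema


# Stand-in demo: a SimpleNamespace plays the role of the config object.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "schema.json")
    with open(path, "w") as f:
        json.dump({"columns": ["a", "b"]}, f)
    demo_cfg = SimpleNamespace(
        json_schema=SimpleNamespace(schema_path=path, schema=None)
    )
    loaded = main(demo_cfg)
```

In a real script the decorator would sit directly beneath @hydra.main(...), as in the snippet above, so that Hydra calls the wrapper with the composed config.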
    

    (Edit Hydra only) - use a callback for postprocessing

    The only other way I see to modify the config without extra Python code in your script is to use callbacks. Note: if you run your script remotely or with multirun, check that the callback is set up correctly.

    from typing import Any

    from hydra.experimental.callback import Callback
    from hydra.types import TaskFunction
    from omegaconf import DictConfig

    from config import JSONSchemaConfig
    
    class ParseJsonCallback(Callback):
       
        def on_job_start(self, config: DictConfig, *, task_function: TaskFunction, **kwargs: Any) -> None:
            """
            Called in both RUN and MULTIRUN modes, once for each Hydra job (before running application code).
            The `task_function` argument is the function
            decorated with `@hydra.main`.
            """
            # modify the config
            JSONSchemaConfig.load_schema(config.json_schema) 
    
    

    To execute the callback, register it, e.g. by adding:

    # conf.yaml

    hydra:
      callbacks:
        insert_json:
          _target_: <modules to>.ParseJsonCallback