Tags: yaml, fb-hydra, omegaconf

Conditional initialization of parameters in Hydra


I'm pretty new to Hydra and was wondering whether the following is possible: I have the parameter num_atom_feats in the model section, which I would like to make dependent on the feat_type parameter in the data section. In particular, if I have feat_type: type1, I would like num_atom_feats: 22. If I instead initialize data with feat_type: type2, I would like num_atom_feats: 200.

model:
  _target_: model.EmbNet_Lightning
  model_name: 'EmbNet'
  num_atom_feats: 22
  dim_target: 128
  loss: 'log_ratio'
  lr: 1e-3
  wd: 5e-6

data:
  _target_: data.DataModule
  feat_type: 'type1'
  batch_size: 64
  data_path: '.'

wandb:
  _target_:  pytorch_lightning.loggers.WandbLogger
  name: embnet_logger
  project: ''

trainer:
  max_epochs: 1000

Solution

  • You can achieve this using OmegaConf's custom resolver feature.

    Here's an example showing how to register a custom resolver that computes model.num_atom_feats based on the value of data.feat_type:

    from omegaconf import OmegaConf
    
    yaml_data = """
    model:
      _target_: model.EmbNet_Lightning
      model_name: 'EmbNet'
      num_atom_feats: ${compute_num_atom_feats:${data.feat_type}}
    
    data:
      _target_: data.DataModule
      feat_type: 'type1'
    """
    
    def compute_num_atom_feats(feat_type: str) -> int:
        if feat_type == "type1":
            return 22
        if feat_type == "type2":
            return 200
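        # Fail loudly on unknown values (a bare assert is skipped under python -O).
        raise ValueError(f"Unsupported feat_type: {feat_type!r}")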
    
    OmegaConf.register_new_resolver("compute_num_atom_feats", compute_num_atom_feats)
    
    
    cfg = OmegaConf.create(yaml_data)
    
    assert cfg.data.feat_type == 'type1'
    assert cfg.model.num_atom_feats == 22
    cfg.data.feat_type = 'type2'
    assert cfg.model.num_atom_feats == 200
    

    I'd recommend reading through the docs of OmegaConf, which is the backend used by Hydra.

    The compute_num_atom_feats function is invoked lazily, when you access cfg.model.num_atom_feats in your Python code.

    When using custom resolvers with Hydra, you can call OmegaConf.register_new_resolver either before you invoke your @hydra.main-decorated function or from within that function. The important thing is that you call OmegaConf.register_new_resolver before you access cfg.model.num_atom_feats.