Use a parameter multiple times in hydra config file


I am currently trying to replace argparse with Hydra config files to set the hyperparameters of a deep neural network.

I have successfully used a config.yaml file, loaded by a Hydra main script, to run training and prediction.

However, the pipeline uses three .py modules, and they share some parameters (for example the file path and the number of labels).

Is there a way, supported by Hydra, to use the same parameter several times in a config.yaml file?

Main file structure:

    import warnings

    import hydra
    from omegaconf import DictConfig

    from segmentation_monai import split, train, predict

    warnings.filterwarnings('ignore', category=UserWarning)


    @hydra.main(config_path='.', config_name="config_bis")
    def my_param(cfg: DictConfig) -> None:
        # Each stage only receives its own sub-config (cfg.split, cfg.train, cfg.predict)
        if cfg.split.run:
            split.main(cfg.split)
        if cfg.train.run:
            train.main(cfg.train)
        if cfg.predict.run:
            predict.main(cfg.predict)


    if __name__ == "__main__":
        my_param()

Config file:

    split:
      run: False
    #  mandatory:
      root_path: D:/breast_seg/db_test
      data_dim: 3
      train_dim: 3
      [...]

    train:
      run: False
    # mandatory:
      root_path: D:/breast_seg/db_test
      data_dim: 3
      train_dim: 3
      [...]

    predict:
      run: True
    # mandatory:
      root_path: D:/breast_seg/db_test
      data_dim: 3
      train_dim: 3
      [...]

Thank you.


Solution

  • You can use the same parameter multiple times in the config by using OmegaConf interpolations.

    
    # Extract the shared values into a dedicated config node.
    # You can also reference one of your existing nodes instead (e.g. ${split.root_path}).
    data:
      root_path: D:/breast_seg/db_test
      data_dim: 3
      train_dim: 3
    
    split:
      run: False
    #  mandatory:
      root_path: ${data.root_path}
      data_dim: ${data.data_dim}
      train_dim: ${data.train_dim}
      [...]
    
    train:
      run: False
    # mandatory:
      root_path: ${data.root_path}
      data_dim: ${data.data_dim}
      train_dim: ${data.train_dim}
      [...]
    
    predict:
      run: True
    # mandatory:
      root_path: ${data.root_path}
      data_dim: ${data.data_dim}
      train_dim: ${data.train_dim}
      [...]