fb-hydra

How to load Hydra parameters from previous jobs (without having to use argparse and the compose API)?


I'm using Hydra to train machine learning models. It's great for running complex commands like python train.py data=MNIST batch_size=64 loss=l2. However, if I then want to run the trained model with the same parameters, I have to do something like python reconstruct.py --config_file path_to_previous_job/.hydra/config.yaml. I then use argparse to load the previous YAML file and the Compose API to initialize the Hydra environment. The path to the trained model is inferred from the path to Hydra's .yaml file. If I want to modify one of the parameters, I have to add extra argparse parameters and run something like python reconstruct.py --config_file path_to_previous_job/.hydra/config.yaml --batch_size 128. The code then manually overrides the corresponding Hydra parameters with the values specified on the command line.

What's the right way of doing this?

My current code looks something like the following:

train.py:

import hydra
from omegaconf import DictConfig

@hydra.main(config_name="config", config_path="conf")
def main(cfg: DictConfig) -> None:
    # [training code using cfg.data, cfg.batch_size, cfg.loss etc.]
    # [code outputs model checkpoint to job folder generated by Hydra]
    ...

if __name__ == "__main__":
    main()

reconstruct.py:

import argparse
import os
from hydra.experimental import initialize_config_dir, compose

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_file', required=True)
    parser.add_argument('--batch_size', type=int)
    # [other flags and parameters I may need to override]
    args = parser.parse_args()

    # Create the Hydra environment from the previous job's .hydra directory.
    # initialize_config_dir expects an absolute path to the config directory.
    config_dir = os.path.abspath(os.path.dirname(args.config_file))
    config_name = os.path.splitext(os.path.basename(args.config_file))[0]
    initialize_config_dir(config_dir=config_dir)
    cfg = compose(config_name=config_name)

    # Since checkpoints are stored next to .hydra, we manually generate the path.
    checkpoint_dir = os.path.dirname(config_dir)

    # Manually override any parameters which can be changed on the command line.
    batch_size = args.batch_size if args.batch_size is not None else cfg.batch_size

    # [code which uses checkpoint_dir to load the model]
    # [code which uses both batch_size and params in cfg to set up the data etc.]

This is my first time posting, so let me know if I should clarify anything.


Solution

If you want to load the previous config as-is, without changing it, use OmegaConf.load(file_path).

If you want to re-compose the config (and it sounds like you do, since you want to override things), I recommend using the Compose API: pass it the overrides stored in the previous job's output directory (.hydra/overrides.yaml, next to the stored config.yaml), concatenated with the overrides of the current run.

This script seems to do the job:

    import os
    from dataclasses import dataclass
    from os.path import join
    from typing import Optional
    
    from omegaconf import OmegaConf
    
    import hydra
    from hydra import compose
    from hydra.core.config_store import ConfigStore
    from hydra.core.hydra_config import HydraConfig
    from hydra.utils import to_absolute_path
    
    # You can also use a yaml config file instead of this Structured Config
    @dataclass
    class Config:
        load_checkpoint: Optional[str] = None
        batch_size: int = 16
        loss: str = "l2"
    
    
    cs = ConfigStore.instance()
    cs.store(name="config", node=Config)
    
    
    @hydra.main(config_path=".", config_name="config")
    def my_app(cfg: Config) -> None:
    
        if cfg.load_checkpoint is not None:
            output_dir = to_absolute_path(cfg.load_checkpoint)
            original_overrides = OmegaConf.load(join(output_dir, ".hydra/overrides.yaml"))
            current_overrides = HydraConfig.get().overrides.task
    
            hydra_config = OmegaConf.load(join(output_dir, ".hydra/hydra.yaml"))
            # getting the config name from the previous job.
            config_name = hydra_config.hydra.job.config_name
            # concatenating the original overrides with the current overrides
            overrides = original_overrides + current_overrides
            # compose a new config from scratch
            cfg = compose(config_name, overrides=overrides)
    
        # train
        print("Running in ", os.getcwd())
        print(OmegaConf.to_yaml(cfg))
    
    
    if __name__ == "__main__":
        my_app()
    
    ~/tmp$ python train.py 
    Running in  /home/omry/tmp/outputs/2021-04-19/21-23-13
    load_checkpoint: null
    batch_size: 16
    loss: l2
    
    ~/tmp$ python train.py load_checkpoint=/home/omry/tmp/outputs/2021-04-19/21-23-13
    Running in  /home/omry/tmp/outputs/2021-04-19/21-23-22
    load_checkpoint: /home/omry/tmp/outputs/2021-04-19/21-23-13
    batch_size: 16
    loss: l2
    
    ~/tmp$ python train.py load_checkpoint=/home/omry/tmp/outputs/2021-04-19/21-23-13 batch_size=32
    Running in  /home/omry/tmp/outputs/2021-04-19/21-23-28
    load_checkpoint: /home/omry/tmp/outputs/2021-04-19/21-23-13
    batch_size: 32
    loss: l2
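In the runs above, the original job was launched without overrides, so no key ends up in the concatenated list twice. If the previous job did override a key that you override again (for example batch_size), the same key is passed to compose() twice. If you would rather not depend on how compose() treats a repeated key, one option is to drop the earlier entries for keys the current run overrides before concatenating. The helper below is a hypothetical sketch (merge_overrides and override_key are not part of Hydra), and it only understands the simple key=value, +key=value, ++key=value and ~key override forms, not sweeps or other grammar features:

```python
from typing import List, Sequence


def override_key(override: str) -> str:
    """Return the config key targeted by a simple Hydra override string.

    Hypothetical helper: handles only 'key=value', '+key=value',
    '++key=value' and '~key' style overrides.
    """
    # Strip the add (+), force-add (++) and delete (~) prefixes.
    key = override.lstrip("+~")
    # Everything before the first '=' is the key ('~key' has no '=').
    return key.split("=", 1)[0]


def merge_overrides(original: Sequence[str], current: Sequence[str]) -> List[str]:
    """Concatenate two override lists so that each key appears only once.

    Entries from `original` whose key is overridden again in `current`
    are dropped; the remaining original entries keep their order and are
    followed by all of `current`.
    """
    current_keys = {override_key(o) for o in current}
    kept = [o for o in original if override_key(o) not in current_keys]
    return kept + list(current)


if __name__ == "__main__":
    previous = ["data=MNIST", "batch_size=64", "loss=l2"]
    now = ["load_checkpoint=outputs/2021-04-19/21-23-13", "batch_size=128"]
    print(merge_overrides(previous, now))
    # ['data=MNIST', 'loss=l2', 'load_checkpoint=outputs/2021-04-19/21-23-13', 'batch_size=128']
```

Inside my_app, the concatenation line would then become overrides = merge_overrides(list(original_overrides), list(current_overrides)); the list() calls turn the OmegaConf ListConfig objects into plain lists of strings.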