python json arguments fb-hydra

Best way to pass configuration with ranges


I have a plotting function that I use for several datasets. The curves are plotted the same way, but the axis limits, ticks, and labels depend on the data I am plotting. Right now I define the settings in dictionaries and look them up like this:

for d in data_keys:
    data = np.load(...) 
    ax.plot(data, label=data_to_label[d])
    ax.xaxis.set_ticks(data_to_xticks[d])
    ax.set_xlim([0, data_to_xlim[d]])

and so on, for other things I need.

The dictionaries look like this:

data_to_label = {
    'easy' : 'Alg. 1 (Debug)',
    'hard' : 'Alg. 2',
}

data_to_xlim = {
    'easy' : 500,
    'hard' : 2000,
}

data_to_xticks = {
    'easy' : [0, 250, 500],
    'hard' : np.arange(0, 2001, 500),
}

data_to_ylim = {
    'easy' : [-0.1, 1.05],
    'hard' : [-0.1, 1.05],
}

data_to_yticks = {
    'easy' : [0, 0.5, 1.],
    'hard' : [0, 0.5, 1.],
}

I have many of these, and I am looking for the best way to store them in config files and load them in my plotting function. I considered Hydra, YAML, and JSON, but none of them lets me specify np.arange() as a parameter value. Ideally, I could pass the config file as an argument when I call python myplot.py.

I could also import them, but then the import must be read from the string passed to myplot.py.
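
On the np.arange() limitation mentioned above: one possible workaround is to store only the range arguments in the JSON file and expand them after loading. The sketch below assumes a made-up convention, a single-key object {"arange": [start, stop, step]}, and the helper names expand_ranges and load_json_config are hypothetical. It uses the built-in range for integer steps; np.arange(*args) could be substituted where float steps are needed.

```python
import json


def expand_ranges(value):
    """Recursively replace {"arange": [start, stop, step]} with a list of ints.

    The "arange" key is an illustrative convention, not part of JSON or NumPy.
    """
    if isinstance(value, dict):
        if set(value) == {'arange'}:
            return list(range(*value['arange']))
        return {k: expand_ranges(v) for k, v in value.items()}
    if isinstance(value, list):
        return [expand_ranges(v) for v in value]
    return value


def load_json_config(text):
    """Parse JSON text and expand any range specifications in it."""
    return expand_ranges(json.loads(text))


EXAMPLE = '''
{
  "easy": {"label": "Alg. 1 (Debug)", "xlim": 500, "xticks": [0, 250, 500]},
  "hard": {"label": "Alg. 2", "xlim": 2000, "xticks": {"arange": [0, 2001, 500]}}
}
'''

config = load_json_config(EXAMPLE)
print(config['hard']['xticks'])  # [0, 500, 1000, 1500, 2000]
```

This keeps the config file plain data, at the cost of a small loader that has to know about each special form.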


Solution

  • I could also import them, but then the import must be read from the string passed to myplot.py

    This can be a good approach, provided you trust the modules you import, because importing a config module executes whatever code it contains. You can do it with the argparse, importlib and inspect modules:

    myplot.py:

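    import argparse
    import importlib.util
    import inspect
    
    def myplot_function():
        # do stuff here; the config dictionaries are module-level globals by now
        print(data_to_label)
        print(data_to_xlim)
        print(data_to_xticks)
        print(data_to_ylim)
        print(data_to_yticks)
    
    
    if __name__ == '__main__':
        # simple cli
        parser = argparse.ArgumentParser(prog='myplot')
        parser.add_argument('-c', '--config', required=True,
                            help='path to a Python config file')
        args = parser.parse_args()
    
        # load the module from its file path, so it need not be on sys.path
        name = inspect.getmodulename(args.config)
        spec = importlib.util.spec_from_file_location(name, args.config)
        cfgmod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(cfgmod)
    
        # inject the public dictionaries into the global namespace
        dicts = {k: v for k, v in inspect.getmembers(cfgmod)
                 if isinstance(v, dict) and not k.startswith('_')}
        globals().update(dicts)
    
        myplot_function()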
    

    Usage:

    [...]$ python myplot.py -c config.py  # -c whatever/the/path/to/config.py
    {'easy': 'Alg. 1 (Debug)', 'hard': 'Alg. 2'}
    {'easy': 500, 'hard': 2000}
    {'easy': [0, 250, 500], 'hard': array([   0,  500, 1000, 1500, 2000])}  # <- HERE
    {'easy': [-0.1, 1.05], 'hard': [-0.1, 1.05]}
    {'easy': [0, 0.5, 1.0], 'hard': [0, 0.5, 1.0]}
    

    config.py:

    import numpy as np
    
    data_to_label = {
        'easy' : 'Alg. 1 (Debug)',
        'hard' : 'Alg. 2',
    }
    
    data_to_xlim = {
        'easy' : 500,
        'hard' : 2000,
    }
    
    data_to_xticks = {
        'easy' : [0, 250, 500],
        'hard' : np.arange(0, 2001, 500),
    }
    
    data_to_ylim = {
        'easy' : [-0.1, 1.05],
        'hard' : [-0.1, 1.05],
    }
    
    data_to_yticks = {
        'easy' : [0, 0.5, 1.],
        'hard' : [0, 0.5, 1.],
    }
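
If writing into globals() feels too implicit, the same idea can be sketched with the dictionaries returned from a function and passed around explicitly. This is only a minimal variant under a few assumptions: the helper names load_config and settings_for are hypothetical, the config file is a trusted .py file, and the dictionaries follow the data_to_<name> naming used above.

```python
import importlib.util
from pathlib import Path


def load_config(path):
    """Execute the Python file at `path` and return its public dicts by name."""
    path = Path(path)
    spec = importlib.util.spec_from_file_location(path.stem, str(path))
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the config file's code
    return {name: value for name, value in vars(module).items()
            if isinstance(value, dict) and not name.startswith('_')}


def settings_for(config, key):
    """Collect the entries for one dataset, e.g. 'xlim' from data_to_xlim[key].

    Relies on the (assumed) data_to_<name> naming convention.
    """
    prefix = 'data_to_'
    return {name[len(prefix):]: table[key]
            for name, table in config.items()
            if name.startswith(prefix) and key in table}
```

With the config.py shown above, settings_for(load_config('config.py'), 'hard') would give a single dictionary with the keys label, xlim, xticks, ylim and yticks, which the plotting loop can read instead of five separate globals.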