
What is the proper way to add type hints after loading a YAML file?


I'm adding type hints to my Python code and was wondering what the proper way is to type-hint a loaded YAML file, since it's a dictionary containing any number of dictionaries.

Is there a better way to type-hint returning a loaded YAML file than Dict[str, Dict[str, Any]]?

Here's the function:

def load_yaml(yaml_in: str) -> Dict[str, Dict[str, Any]]:
    with open(yaml_in) as f:  # close the file after loading
        return yaml.load(f, Loader=yaml.FullLoader)

Here's an example of the YAML file being loaded:

VariableMap:
    var1: 'time'
    var2: 'param_name'

GlobalVariables:
    limits:
        x-min:
        x-max:
        y-min:
        y-max:

Plots:
    plot1:
        file: 
        x_data: 'date'
        y_data: [{param: 'param1', label: "param1", color: 'red', linestyle: '-'},
                 {param: 'param2', label: "param2", color: 'black', linestyle: '--'}]
        labels:
            title: {label: 'title', fontsize: '9'}
            x-axis: {xlabel: 'x-label', fontsize: '9'}
            y-axis: {ylabel: 'y-label', fontsize: '9'}
        limits:
            x-min: 0
            x-max: 100
            y-min:
            y-max:

Figures:
    fig1:
        shape: [1, 1]
        size: [6, 8]
        plots: ['plot1']

Solution

  • I would suggest looking into the dataclass-wizard library, which might be helpful for this task. In particular, it provides the YAMLWizard mixin class, which simplifies converting between YAML data and dataclass instances.

    First, store the YAML data in a string variable:

    yaml_string = """
    VariableMap:
        var1: 'time'
        var2: 'param_name'
    
    GlobalVariables:
        limits:
            x-min:
            x-max:
            y-min:
            y-max:
    
    Plots:
        plot1:
            file:
            x_data: 'date'
            y_data: [{param: 'param1', label: "param1", color: 'red', linestyle: '-'},
                     {param: 'param2', label: "param2", color: 'black', linestyle: '--'}]
            labels:
                title: {label: 'title', fontsize: '9'}
                x-axis: {xlabel: 'x-label', fontsize: '9'}
                y-axis: {ylabel: 'y-label', fontsize: '9'}
            limits:
                x-min: 0
                x-max: 100
                y-min:
                y-max:
    
    Figures:
        fig1:
            shape: [1, 1]
            size: [6, 8]
            plots: ['plot1']
    """
    

    Next, you can use the library's CLI utility module to generate a very rough dataclass schema from the parsed data, as follows:

    import json
    import yaml
    
    import dataclass_wizard.wizard_cli as cli
    
    
    data = yaml.safe_load(yaml_string)
    
    print(cli.PyCodeGenerator(file_contents=json.dumps(data), experimental=True).py_code)
    

    This outputs something like the code below. Note that I've gone ahead and cleaned up some issues, such as duplicate classes and "unknown" types (for example, for y-min and y-max).
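Why would fields like y-min and y-max come out as "unknown"? A generator can presumably only infer a type from the values it sees, and a bare `y-min:` in YAML parses to None. The rough sketch below illustrates that kind of value-based inference with only the standard library. `infer_type` is a hypothetical helper, not dataclass-wizard's code, and it runs on values as `yaml.safe_load()` would return them for the YAML above.

```python
from typing import Any


def infer_type(value: Any) -> str:
    """Return a rough type-hint string for one parsed YAML value.

    Hypothetical helper for illustration only; it is not part of
    dataclass-wizard, and it reports nested mappings as plain "dict"
    instead of generating nested dataclasses.
    """
    if value is None:
        # A bare `key:` in YAML parses to None, so there is nothing to infer from
        return "Any"
    if isinstance(value, bool):  # check bool before int: bool subclasses int
        return "bool"
    if isinstance(value, (int, float, str)):
        return type(value).__name__
    if isinstance(value, list):
        # Union of the distinct element types, e.g. [1, 1] -> list[int]
        inner = sorted({infer_type(v) for v in value}) or ["Any"]
        return f"list[{' | '.join(inner)}]"
    if isinstance(value, dict):
        return "dict"  # a real generator would emit a nested dataclass here
    return "Any"


# A few values as yaml.safe_load() returns them for the YAML in the question
limits = {"x-min": 0, "x-max": 100, "y-min": None, "y-max": None}
fig1 = {"shape": [1, 1], "size": [6, 8], "plots": ["plot1"]}

for key, value in {**limits, **fig1}.items():
    print(f"{key}: {infer_type(value)}")  # e.g. "y-min: Any", "shape: list[int]"
```

Because null values give no type information, those fields have to be annotated by hand, which is the cleanup mentioned above.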

    The full dataclass schema:

    from __future__ import annotations
    
    from dataclasses import dataclass
    from typing import Any
    
    from dataclass_wizard import YAMLWizard
    
    
    @dataclass
    class Data(YAMLWizard):
        """
        Data dataclass
    
        """
        variable_map: VariableMap
        global_variables: GlobalVariables
        plots: Plots
        figures: Figures
    
    
    @dataclass
    class VariableMap:
        """
        VariableMap dataclass
    
        """
        var1: str
        var2: str
    
    
    @dataclass
    class GlobalVariables:
        """
        GlobalVariables dataclass
    
        """
        limits: Limits
    
    
    @dataclass
    class Plots:
        """
        Plots dataclass
    
        """
        plot1: Plot
    
    
    @dataclass
    class Plot:
        """
        Plot dataclass
    
        """
        file: Any
        x_data: str
        y_data: list[YDatum]
        labels: Labels
        limits: Limits
    
    
    @dataclass
    class YDatum:
        """
        YDatum dataclass
    
        """
        param: str
        label: str
        color: str
        linestyle: str
    
    
    @dataclass
    class Labels:
        """
        Labels dataclass
    
        """
        title: Title
        x_axis: XAxis
        y_axis: YAxis
    
    
    @dataclass
    class Title:
        """
        Title dataclass
    
        """
        label: str
        fontsize: int | str
    
    
    @dataclass
    class XAxis:
        """
        XAxis dataclass
    
        """
        xlabel: str
        fontsize: int | str
    
    
    @dataclass
    class YAxis:
        """
        YAxis dataclass
    
        """
        ylabel: str
        fontsize: int | str
    
    
    @dataclass
    class Limits:
        """
        Limits dataclass
    
        """
        x_min: int
        x_max: int
        y_min: int
        y_max: int
    
    
    @dataclass
    class Figures:
        """
        Figures dataclass
    
        """
        fig1: Fig1
    
    
    @dataclass
    class Fig1:
        """
        Fig1 dataclass
    
        """
        shape: list[int]
        size: list[int]
        plots: list[str]
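The field names in this schema are snake_case versions of the YAML keys (VariableMap becomes variable_map, x-min becomes x_min), so loading has to match keys to fields by normalizing their spelling. A minimal stdlib sketch of that kind of normalization follows; `to_snake_case` is a hypothetical helper, and dataclass-wizard has its own key-casing logic that this does not reproduce exactly.

```python
import re


def to_snake_case(key: str) -> str:
    """Convert a YAML key such as 'VariableMap' or 'x-min' to 'variable_map' / 'x_min'.

    Hypothetical helper for illustration; not dataclass-wizard's implementation.
    """
    key = key.replace("-", "_")
    # Insert "_" wherever an uppercase letter follows a lowercase letter or digit
    key = re.sub(r"(?<=[a-z0-9])(?=[A-Z])", "_", key)
    return key.lower()


for key in ["VariableMap", "GlobalVariables", "x-min", "y_data", "x-axis"]:
    print(f"{key!r} -> {to_snake_case(key)!r}")
```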
    

    Now we can load the YAML string into a nested Data object:

    import pprint
    
    data = Data.from_yaml(yaml_string)
    pprint.pprint(data)
    

    Output:

    Data(variable_map=VariableMap(var1='time', var2='param_name'),
         global_variables=GlobalVariables(limits=Limits(x_min=0,
                                                        x_max=0,
                                                        y_min=0,
                                                        y_max=0)),
         plots=Plots(plot1=Plot(file=None,
                                x_data='date',
                                y_data=[YDatum(param='param1',
                                               label='param1',
                                               color='red',
                                               linestyle='-'),
                                        YDatum(param='param2',
                                               label='param2',
                                               color='black',
                                               linestyle='--')],
                                labels=Labels(title=Title(label='title',
                                                          fontsize='9'),
                                              x_axis=XAxis(xlabel='x-label',
                                                           fontsize='9'),
                                              y_axis=YAxis(ylabel='y-label',
                                                           fontsize='9')),
                                limits=Limits(x_min=0, x_max=100, y_min=0, y_max=0))),
         figures=Figures(fig1=Fig1(shape=[1, 1], size=[6, 8], plots=['plot1'])))
    

    Observations

    I noticed that some classes, in particular XAxis and YAxis, contain essentially the same fields, but those fields have slightly different key names in the YAML data (xlabel vs. ylabel).

    If desired, such classes could be merged into a single class declaration. Then, a key-mapping approach such as json_field() can define alias key names to use when loading the YAML data, and Labels.x_axis and Labels.y_axis would both be annotated with the merged class.

    For example:

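    from __future__ import annotations
    
    from dataclasses import dataclass
    
    from dataclass_wizard import json_field
    
    
    @dataclass
    class Axis:
        # noinspection PyDataclass
        # Either an 'xlabel' or a 'ylabel' key in the YAML data populates this field
        label: str = json_field(('xlabel', 'ylabel'))
        fontsize: int | str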
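Without the library, the same alias idea boils down to checking a list of candidate keys for one field. A minimal stdlib-only sketch for the merged Axis class follows. `axis_from_dict` and `LABEL_ALIASES` are hypothetical names used for illustration, and this is not how json_field() is implemented.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass
class Axis:
    label: str
    fontsize: int | str


# Hypothetical alias table: the YAML keys that may populate Axis.label
LABEL_ALIASES = ("xlabel", "ylabel")


def axis_from_dict(data: dict[str, Any]) -> Axis:
    """Build one Axis from either an x-axis or a y-axis mapping (stdlib-only sketch)."""
    for key in LABEL_ALIASES:
        if key in data:
            return Axis(label=data[key], fontsize=data["fontsize"])
    raise KeyError(f"none of {LABEL_ALIASES} found in {data!r}")


# Two entries of the 'labels' mapping under Plots.plot1, as yaml.safe_load() returns them
labels = {
    "x-axis": {"xlabel": "x-label", "fontsize": "9"},
    "y-axis": {"ylabel": "y-label", "fontsize": "9"},
}

x_axis = axis_from_dict(labels["x-axis"])
y_axis = axis_from_dict(labels["y-axis"])
print(x_axis)  # Axis(label='x-label', fontsize='9')
print(y_axis)  # Axis(label='y-label', fontsize='9')
```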