Tags: machine-learning, yaml, pipeline, kedro

Kedro - how to pass nested parameters directly to node


Kedro recommends storing parameters in conf/base/parameters.yml. Let's assume it looks like this:

step_size: 1
model_params:
    learning_rate: 0.01
    test_data_ratio: 0.2
    num_train_steps: 10000

And now imagine I have some data_engineering pipeline whose nodes.py has a function that looks something like this:

def some_pipeline_step(num_train_steps):
    """
    Takes the parameter `num_train_steps` as argument.
    """
    pass

How would I go about passing that nested parameter straight to this function in data_engineering/pipeline.py? I unsuccessfully tried:

from kedro.pipeline import Pipeline, node

from .nodes import some_pipeline_step


def create_pipeline(**kwargs):
    return Pipeline(
        [
            node(
                some_pipeline_step,
                ["params:model_params.num_train_steps"],
                dict(
                    train_x="train_x",
                    train_y="train_y",
                ),
            )
        ]
    )

I know that I could pass all parameters into the function by using ['parameters'], or pass all of the model_params parameters with ['params:model_params'], but that seems inelegant and I feel like there must be a better way. Would appreciate any input!


Solution

  • (Disclaimer: I'm part of the Kedro team)

    Thank you for your question. Unfortunately, the current version of Kedro does not support nested parameters. The interim solution would be to pass a top-level key and pick out the nested value inside the node (as you already pointed out), or to decorate your node function with some sort of parameter filter, which is not elegant either.
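
A minimal sketch of what such a parameter filter could look like, in plain Python. The decorator name with_nested_params and its dotted-path convention are hypothetical and not part of Kedro; the node would then be wired with ["parameters"] as its input. How Kedro's own node input validation treats a decorated function has not been verified here.

```python
from functools import wraps
from typing import Any, Callable, Dict


def with_nested_params(**arg_paths: str) -> Callable:
    """Hypothetical helper (not a Kedro API): map keyword arguments to
    dotted paths inside the full parameters dictionary."""

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(parameters: Dict[str, Any], *args: Any) -> Any:
            kwargs = {}
            for arg_name, path in arg_paths.items():
                value: Any = parameters
                # Walk down the nested dictionaries one key at a time.
                for part in path.split("."):
                    value = value[part]  # KeyError if the path does not exist
                kwargs[arg_name] = value
            return func(*args, **kwargs)

        return wrapper

    return decorator


@with_nested_params(num_train_steps="model_params.num_train_steps")
def some_pipeline_step(num_train_steps):
    # Returns the value only so the example has something to inspect.
    return num_train_steps


# Same structure as the parameters.yml from the question.
parameters = {
    "step_size": 1,
    "model_params": {
        "learning_rate": 0.01,
        "test_data_ratio": 0.2,
        "num_train_steps": 10000,
    },
}
result = some_pipeline_step(parameters)
print(result)
```

The function body stays free of any knowledge about the parameter layout; only the decorator arguments need to change if parameters.yml is reorganised.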

    Probably the most viable solution would be to customise your ProjectContext class (in src/<package_name>/run.py) by overriding the _get_feed_dict method as follows:

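    from typing import Any, Dict

    # Import path assumed from the project template's run.py for this Kedro version.
    from kedro.context import KedroContext


    class ProjectContext(KedroContext):
        # ...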
    
    
        def _get_feed_dict(self) -> Dict[str, Any]:
            """Get parameters and return the feed dictionary."""
            params = self.params
            feed_dict = {"parameters": params}
    
            def _add_param_to_feed_dict(param_name, param_value):
                """This recursively adds parameter paths to the `feed_dict`,
                whenever `param_value` is a dictionary itself, so that users can
                specify specific nested parameters in their node inputs.
    
                Example:
    
                    >>> param_name = "a"
                    >>> param_value = {"b": 1}
                    >>> _add_param_to_feed_dict(param_name, param_value)
                    >>> assert feed_dict["params:a"] == {"b": 1}
                    >>> assert feed_dict["params:a.b"] == 1
                """
                key = "params:{}".format(param_name)
                feed_dict[key] = param_value
    
                if isinstance(param_value, dict):
                    for sub_key, sub_value in param_value.items():
                        _add_param_to_feed_dict("{}.{}".format(param_name, sub_key), sub_value)
    
            for param_name, param_value in params.items():
                _add_param_to_feed_dict(param_name, param_value)
    
            return feed_dict
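
To see what the override produces, here is the same recursive flattening as a standalone sketch that runs without Kedro. The function name build_feed_dict is hypothetical, and the input dictionary mirrors the parameters.yml from the question.

```python
from typing import Any, Dict


def build_feed_dict(params: Dict[str, Any]) -> Dict[str, Any]:
    """Standalone copy of the flattening logic (hypothetical name)."""
    feed_dict: Dict[str, Any] = {"parameters": params}

    def _add(param_name: str, param_value: Any) -> None:
        # Register the value itself, then recurse into nested dictionaries.
        feed_dict["params:{}".format(param_name)] = param_value
        if isinstance(param_value, dict):
            for sub_key, sub_value in param_value.items():
                _add("{}.{}".format(param_name, sub_key), sub_value)

    for name, value in params.items():
        _add(name, value)
    return feed_dict


params = {
    "step_size": 1,
    "model_params": {
        "learning_rate": 0.01,
        "test_data_ratio": 0.2,
        "num_train_steps": 10000,
    },
}
feed_dict = build_feed_dict(params)

# Every level gets its own key, so both the whole sub-dictionary and
# each leaf value can be requested by name.
for key in sorted(feed_dict):
    print(key)
```

With the override in place, "params:model_params.num_train_steps" becomes a valid node input, alongside the existing "parameters" and "params:model_params" keys.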
    

    Please also note that this issue has already been addressed on the develop branch and the fix will become available in the next release. It uses the approach from the snippet above.