Kedro recommends storing parameters in conf/base/parameters.yml. Let's assume it looks like this:
step_size: 1
model_params:
    learning_rate: 0.01
    test_data_ratio: 0.2
    num_train_steps: 10000
And now imagine I have some data_engineering pipeline whose nodes.py has a function that looks something like this:
def some_pipeline_step(num_train_steps):
    """
    Takes the parameter `num_train_steps` as argument.
    """
    pass
How would I pass that nested parameter straight to this function in data_engineering/pipeline.py? I unsuccessfully tried:
from kedro.pipeline import Pipeline, node

from .nodes import some_pipeline_step


def create_pipeline(**kwargs):
    return Pipeline(
        [
            node(
                some_pipeline_step,
                ["params:model_params.num_train_steps"],
                dict(
                    train_x="train_x",
                    train_y="train_y",
                ),
            )
        ]
    )
I know that I could just pass all parameters into the function by using ['parameters'], or pass all of the model_params parameters with ['params:model_params'], but that seems inelegant and I feel like there must be a better way. Would appreciate any input!
(Disclaimer: I'm part of the Kedro team)
Thank you for your question. The current version of Kedro, unfortunately, does not support nested parameters. The interim solution would be to pass a top-level key to the node and pick out the nested value inside it (as you already pointed out), or to decorate your node function with some sort of parameter filter, which is not elegant either.
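For illustration, a parameter filter of that kind could look roughly like the sketch below. The decorator name with_nested_param is hypothetical and not part of Kedro; the sketch uses plain Python only and assumes the node receives the whole model_params dictionary as its input.

```python
from functools import wraps
from typing import Any, Callable, Mapping


def with_nested_param(*path: str) -> Callable:
    """Hypothetical helper (not a Kedro API): pass one nested value to a function.

    The wrapped function accepts a parameter dictionary and is called with
    the value found by following `path` key by key.
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(params: Mapping[str, Any], *args, **kwargs):
            value = params
            for key in path:
                value = value[key]  # raises KeyError if the path does not exist
            return func(value, *args, **kwargs)

        return wrapper

    return decorator


@with_nested_param("num_train_steps")
def some_pipeline_step(num_train_steps):
    """Takes the parameter `num_train_steps` as argument."""
    return num_train_steps


# Stand-in for the dictionary a node would receive for "params:model_params".
model_params = {
    "learning_rate": 0.01,
    "test_data_ratio": 0.2,
    "num_train_steps": 10000,
}
result = some_pipeline_step(model_params)
```

The node input in pipeline.py would then stay ["params:model_params"], and the decorator does the unpacking, which is exactly the extra indirection that makes this approach clumsy.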
Probably the most viable solution would be to customise your ProjectContext class (in src/<package_name>/run.py) by overriding the _get_feed_dict method as follows:
# Imports this snippet relies on (run.py may already have some of them).
from typing import Any, Dict

from kedro.context import KedroContext


class ProjectContext(KedroContext):
    # ...

    def _get_feed_dict(self) -> Dict[str, Any]:
        """Get parameters and return the feed dictionary."""
        params = self.params
        feed_dict = {"parameters": params}

        def _add_param_to_feed_dict(param_name, param_value):
            """This recursively adds parameter paths to the `feed_dict`,
            whenever `param_value` is a dictionary itself, so that users can
            specify specific nested parameters in their node inputs.

            Example:

                >>> param_name = "a"
                >>> param_value = {"b": 1}
                >>> _add_param_to_feed_dict(param_name, param_value)
                >>> assert feed_dict["params:a"] == {"b": 1}
                >>> assert feed_dict["params:a.b"] == 1
            """
            key = "params:{}".format(param_name)
            feed_dict[key] = param_value

            if isinstance(param_value, dict):
                for key, val in param_value.items():
                    _add_param_to_feed_dict("{}.{}".format(param_name, key), val)

        for param_name, param_value in params.items():
            _add_param_to_feed_dict(param_name, param_value)

        return feed_dict
Please also note that this issue has already been addressed on develop and will become available in the next release. The fix uses the approach from the snippet above.
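To check which keys the override would produce for the parameters.yml from the question, here is a standalone sketch of the same flattening logic with no Kedro dependency. The function name build_feed_dict and the hand-written params dictionary are assumptions made for the example; in a real project the dictionary comes from the context's params.

```python
from typing import Any, Dict


def build_feed_dict(params: Dict[str, Any]) -> Dict[str, Any]:
    """Standalone copy of the flattening logic (illustrative, not Kedro code)."""
    feed_dict: Dict[str, Any] = {"parameters": params}

    def _add(name: str, value: Any) -> None:
        # Register the value under its dotted path, then recurse into dicts.
        feed_dict["params:{}".format(name)] = value
        if isinstance(value, dict):
            for sub_name, sub_value in value.items():
                _add("{}.{}".format(name, sub_name), sub_value)

    for name, value in params.items():
        _add(name, value)
    return feed_dict


# Dictionary equivalent of the parameters.yml shown in the question.
params = {
    "step_size": 1,
    "model_params": {
        "learning_rate": 0.01,
        "test_data_ratio": 0.2,
        "num_train_steps": 10000,
    },
}
feed_dict = build_feed_dict(params)

# List every name that a node input could refer to.
for key in sorted(feed_dict):
    print(key)
```

With the override in place, the node input from the question, "params:model_params.num_train_steps", should then resolve to 10000, while "params:model_params" and "parameters" keep working as before.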