Tags: python-3.x, kubernetes, airflow, jinja2, astronomer

Pull xcom from KubernetesPodOperator


I have a DAG that uses KubernetesPodOperator. The task get_train_test_model_task_count in the DAG pushes an XCom value, and I want to use that value in the downstream tasks.

run_this = BashOperator(
    task_id="also_run_this",
    bash_command='echo "ti_key={{ ti.xcom_pull(task_ids=\"get_train_test_model_task_count\", key=\"return_value\")[\"models_count\"] }}"',
)

The above DAG task works and it prints the value as ti_key=24.

I want to use the same value as a Python variable:

with TaskGroup("train_test_model_config") as train_test_model_config:
    models_count = "{{ ti.xcom_pull(task_ids=\"get_train_test_model_task_count\", key=\"return_value\")[\"models_count\"] }}"
    print(models_count)
    for task_num in range(0, int(models_count)):
        generate_train_test_model_config_task(task_num)

int(models_count) does not work; it throws this error:

ValueError: invalid literal for int() with base 10: '{{ ti.xcom_pull(task_ids="get_train_test_model_task_count", key="return_value")["models_count"] }}'

And the generate_train_test_model_config_task looks as below:

def generate_train_test_model_config_task(task_num):
    task = KubernetesPodOperator(
        name=f"train_test_model_config_{task_num}",
        image=build_model_image,
        labels=labels,
        cmds=[
            "python3",
            "-m",
            "src.models.train_test_model_config",
            "--tenant=neu",
            f"--model_tag_id={task_num}",
            "--line_plan={{ ti.xcom_pull(key=\"file_name\", task_ids=\"extract_file_name\") }}",
            "--staging_bucket=cs-us-ds"
        ],
        task_id=f"train_test_model_config_{task_num}",
        do_xcom_push=False,
        namespace="airflow",
        service_account_name="airflow-worker",
        get_logs=True,
        startup_timeout_seconds=300,
        container_resources={"request_memory": "29G", "request_cpu": "7000m"},
        node_selector={"cloud.google.com/gke-nodepool": NODE_POOL},
        tolerations=[
            {
                "key": NODE_POOL,
                "operator": "Equal",
                "value": "true",
                "effect": "NoSchedule",
            }
        ],
    )

    return task

Solution

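  • The Jinja template pulls from the Airflow context, which is only available inside a task at run time, not in top-level DAG code. When the DAG file is parsed, the template is still a plain string, which is why int() fails on it.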
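To make the parse-time versus run-time difference concrete, here is a minimal plain-Python sketch. It does not use Airflow at all; `to_int_or_none` is a hypothetical helper, and the rendered value "24" is taken from the question's output.

```python
# The template exactly as it appears in the DAG file. While the DAG file is
# being parsed this is an ordinary Python string; nothing has rendered it yet.
TEMPLATE = (
    '{{ ti.xcom_pull(task_ids="get_train_test_model_task_count", '
    'key="return_value")["models_count"] }}'
)


def to_int_or_none(value):
    """Hypothetical helper: return int(value), or None if it is not a number."""
    try:
        return int(value)
    except ValueError:
        return None


# Parse time: int() sees the curly braces and raises ValueError.
parse_time_result = to_int_or_none(TEMPLATE)

# Run time: inside a task the template has already been replaced by the XCom
# value. "24" is the value the question reports for ti_key.
run_time_result = to_int_or_none("24")

print(parse_time_result, run_time_result)
```

The first call fails for the same reason as the loop in the TaskGroup: the conversion runs before any task has executed, so there is no XCom value to substitute.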

    Also, as a commenter said, you will need dynamic task mapping to change the DAG structure at run time. Even if you hardcode the model count or template it in some other way, the scheduler only picks up such code changes every 30 seconds by default, and you lose visibility into earlier runs: if one day there are only 2 models, you can't see models 3 through 8 in the logs from the day before. A loop like that gets messy even if you can get it to work.

    The code below shows the structure that I think will achieve what you want: one model config generated for each task_num. This should work in Airflow 2.3+.

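    from airflow.decorators import task


    @task
    def generate_list_of_model_nums(**context):
        # Runs inside a task, so the XCom can be pulled from the context.
        model_count = context["ti"].xcom_pull(task_ids="get_train_test_model_task_count", key="return_value")["models_count"]
        # range(model_count) yields 0..model_count-1, matching the loop in the question.
        return list(range(model_count))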
    
    
    @task
    def generate_train_test_model_config_task(task_num):
        # code that generates the model config
        return model_config
    
    model_nums = generate_list_of_model_nums()
    generate_train_test_model_config_task.expand(task_num=model_nums)
    

    Note: I did not test the code above, so there may be typos, but the general idea is to create a list of all the task numbers and then use dynamic task mapping to expand over that list.

    If you pull the XCom from the generate_train_test_model_config_task you should get a list of all the model configs :)
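As a rough mental model of the mapping described above, the plain-Python sketch below imitates what happens to the data: one call per list element, with the results collected into a list. It is not Airflow code. `expand` here is a hypothetical stand-in for the operator's `.expand()`, and the returned dict is a made-up placeholder for a real model config.

```python
def generate_list_of_model_nums(model_count):
    """Stand-in for the upstream task: one entry per model."""
    return list(range(model_count))


def generate_train_test_model_config(task_num):
    """Stand-in for the mapped task; the dict is a made-up placeholder."""
    return {"model_tag_id": task_num}


def expand(fn, values):
    """Hypothetical helper imitating .expand(): one call per list element."""
    return [fn(value) for value in values]


model_nums = generate_list_of_model_nums(3)
configs = expand(generate_train_test_model_config, model_nums)
print(configs)
```

In Airflow each of those calls would be its own mapped task instance with its own logs, which is what gives you the per-run visibility that a parse-time loop lacks.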

    Some resources that might help to adapt this to traditional operators:

    Disclaimer: I work at Astronomer, the org that created the guides above :)

    EDIT: thanks for sharing the KPO code! I see you use task_num in two parameters, which means you can try .expand_kwargs over a list of input sets, each in the form of a dictionary, and map the KPO directly. Note that this is an Airflow 2.4+ feature.

    Note on the code: I tested the dict generation function, but I don't have a K8s cluster running right now, so I did not test the latter part. I think name and cmds should be expandable 🤞

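    from airflow.decorators import task


    @task
    def generate_list_of_param_dicts(**context):
        ti = context["ti"]
        model_count = ti.xcom_pull(
            task_ids="get_train_test_model_task_count", key="return_value"
        )["models_count"]
        # Mapped arguments are not rendered as Jinja templates (per the Airflow
        # dynamic task mapping docs), so resolve the file name here instead of
        # embedding a template string in cmds.
        line_plan = ti.xcom_pull(key="file_name", task_ids="extract_file_name")
    
        param_dicts = []
        for i in range(model_count):
            # One dict per mapped pod; the keys are KubernetesPodOperator arguments.
            param_dict = {
                "name": f"train_test_model_config_{i}",
                "cmds": [
                    "python3",
                    "-m",
                    "src.models.train_test_model_config",
                    "--tenant=neu",
                    f"--model_tag_id={i}",
                    f"--line_plan={line_plan}",
                    "--staging_bucket=cs-us-ds",
                ],
            }
            param_dicts.append(param_dict)
    
        return param_dicts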
    
    
    task = KubernetesPodOperator.partial(
        image=build_model_image,
        labels=labels,
        task_id="train_test_model_config",
        do_xcom_push=False,
        namespace="airflow",
        service_account_name="airflow-worker",
        get_logs=True,
        startup_timeout_seconds=300,
        container_resources={"request_memory": "29G", "request_cpu": "7000m"},
        node_selector={"cloud.google.com/gke-nodepool": NODE_POOL},
        tolerations=[
            {
                "key": NODE_POOL,
                "operator": "Equal",
                "value": "true",
                "effect": "NoSchedule",
            }
        ],
    ).expand_kwargs(generate_list_of_param_dicts())
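
Since the pod part is hard to try without a cluster, one option is to keep the dict-building logic in a plain function that can be checked on its own. The sketch below is an assumption about how you might factor it, not part of the tested answer: `build_param_dicts` is a hypothetical name, and `line_plan` is passed in as an already-resolved string rather than a template.

```python
def build_param_dicts(model_count, line_plan):
    """Build one kwargs dict per model, keyed by operator argument name."""
    param_dicts = []
    for i in range(model_count):
        param_dicts.append(
            {
                "name": f"train_test_model_config_{i}",
                "cmds": [
                    "python3",
                    "-m",
                    "src.models.train_test_model_config",
                    "--tenant=neu",
                    f"--model_tag_id={i}",
                    f"--line_plan={line_plan}",
                    "--staging_bucket=cs-us-ds",
                ],
            }
        )
    return param_dicts


# "plan.csv" is a made-up file name used only for this check.
params = build_param_dicts(2, "plan.csv")
for entry in params:
    print(entry["name"], entry["cmds"][4])
```

The @task function would then only pull the two XCom values and return build_param_dicts(model_count, line_plan), so the shape passed to .expand_kwargs can be verified without running any pods.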