Search code examples
python-3.x airflow airflow-2.x

Airflow Branch Operator inside Task Group with Invalid Task IDs


def skip_update_job_pod_name(dag):
    """Build the no-op task used when the pod name update is skipped.

    :param dag: Airflow DAG the task is attached to
    :return: DummyOperator whose task_id is "skip_update_job_pod_name"
    """
    skip_task = DummyOperator(dag=dag, task_id="skip_update_job_pod_name")
    return skip_task


def update_pod_name_branch_operator(dag: DAG, job_id: str):
    """Create the branching task that decides whether the pod name is updated.

    :param dag: Airflow DAG the task is attached to
    :param job_id: Airflow job ID, forwarded to ``update_pod_name_func``
    :return: BranchPythonOperator whose task_id is "update_pod_name"
    """
    branch_settings = {
        "task_id": "update_pod_name",
        "python_callable": update_pod_name_func,
        # The callable receives job_id as a keyword argument.
        "op_kwargs": {"job_id": job_id},
        "trigger_rule": "all_done",
        "dag": dag,
    }
    return BranchPythonOperator(**branch_settings)


def update_pod_name_func(job_id: Optional[str], group_id: Optional[str] = None) -> list:
    """Choose which downstream task the branch operator should follow.

    :param job_id: Airflow job ID; a falsy value (``None`` or ``""``) selects
        the skip task
    :param group_id: ID of the enclosing TaskGroup, if any. A task created
        inside a TaskGroup is registered as ``<group_id>.<task_id>``, so the
        returned ID must carry the same prefix or the branch operator rejects
        it as an invalid task ID. Pass it through ``op_kwargs``; leave it as
        ``None`` when the tasks are not inside a TaskGroup.
    :return: single-element list holding the task_id to follow
    """
    # The skip task is created with task_id "skip_update_job_pod_name", so the
    # branch must return exactly that ID (not "skip_update_pod_name").
    task_id = "update_job_pod_name" if job_id else "skip_update_job_pod_name"
    if group_id:
        task_id = f"{group_id}.{task_id}"
    return [task_id]


def update_job_pod_name(dag: DAG, job_id: str, process_name: str) -> MySqlOperator:
    """Build the task that records the running pod's name for a job.

    The statement inserts one row into ``airflow.Pod`` unless a row with the
    same pod name already exists.

    :param dag: Airflow DAG
    :param job_id: Airflow Job ID
    :param process_name: name of the current running process, stored as task_name
    :return: MySqlOperator that inserts the (job_id, pod_name, task_name) row
    """
    # The pod name is only known at run time, so it must reach the operator as
    # a Jinja expression. Calling xcom_pull(...) inside the f-string would be
    # evaluated as Python while the DAG is parsed and fail with a NameError.
    pod_name_template = '{{ ti.xcom_pull(key="pod_name") }}'

    # NOTE(review): job_id and process_name are interpolated straight into the
    # SQL text -- confirm they never come from untrusted input.
    return MySqlOperator(
        task_id="update_job_pod_name",
        mysql_conn_id="semantic-search-airflow-sdk",
        autocommit=True,
        sql=[
            f"""
                INSERT INTO airflow.Pod (job_id, pod_name, task_name)
                SELECT * FROM (SELECT '{job_id}', '{pod_name_template}', '{process_name}') AS temp
                WHERE NOT EXISTS (
                    SELECT pod_name FROM airflow.Pod WHERE pod_name = '{pod_name_template}'
                ) LIMIT 1;
            """
        ],
        task_concurrency=1,
        dag=dag,
        trigger_rule="all_done",
    )

I am putting a branch operator inside a task group, but I get an error raised from the line f"Branch callable must return valid task_ids. Invalid tasks found: {invalid_task_ids}": airflow.exceptions.AirflowException: Branch callable must return valid task_ids. Invalid tasks found: {'update_job_pod_name'}. What could be causing this?


Solution

  • The reason is that a task inside a TaskGroup gets a task_id that follows the TaskGroup naming convention: the group ID is prepended to it.

    For example, if the group is called "tg1" and the task_id is "update_pod_name", then the task's final ID in the DAG is tg1.update_pod_name.

    The best way to solve it is to reference the variable that the operator is assigned to and return its task_id, which already includes the group prefix:

    def update_pod_name_func(job_id: Optional[str]) -> str:
        """Return the task_id of the branch to follow.

        NOTE(review): assumes ``update_job_pod_name`` and
        ``skip_update_pod_name`` are the variables holding the operator
        instances created inside the TaskGroup, so that their ``task_id``
        already carries the group prefix -- confirm against the DAG definition.
        """
        if not job_id:
            return skip_update_pod_name.task_id
        return update_job_pod_name.task_id