Search code examples
python-3.x, airflow, airflow-2.x

How to run a certain operator only when a DAG conf value exists



def skip_update_job_pod_name(dag):
    """
    Build the no-op task that is followed when the pod name is not recorded.

    :param dag: Airflow DAG the task is attached to
    :return: DummyOperator whose task_id is ``skip_update_job_pod_name``
    """
    skip_task = DummyOperator(dag=dag, task_id="skip_update_job_pod_name")
    return skip_task


def update_pod_name_branch_operator(dag: DAG, job_id: str):
    """
    Build the branch task that decides whether the pod name is recorded.

    :param dag: Airflow DAG the task is attached to
    :param job_id: value forwarded to ``update_pod_name_func`` as its
        ``job_id`` keyword argument
    :return: BranchPythonOperator whose task_id is ``update_pod_name``
    """
    callable_kwargs = {"job_id": job_id}
    branch_task = BranchPythonOperator(
        task_id="update_pod_name",
        python_callable=update_pod_name_func,
        op_kwargs=callable_kwargs,
        # "all_done": run once every upstream task has finished, whatever its outcome.
        trigger_rule="all_done",
        dag=dag,
    )
    return branch_task


def update_pod_name_func(job_id: Optional[str], group_id: Optional[str] = None) -> str:
    """
    Pick the task_id that the branch operator should follow.

    :param job_id: job identifier; any falsy value (``None``, ``""``) selects
        the skip task
    :param group_id: id of the enclosing TaskGroup, if any. Tasks created
        inside a TaskGroup are registered as ``<group_id>.<task_id>`` by
        default, so a branch callable must return the prefixed id for them.
        Leave as ``None`` for tasks that are not in a group.
    :return: task_id of the task to run next
    """
    # The returned ids must match the task_id given to the operators themselves;
    # the skip task is registered as "skip_update_job_pod_name".
    task_id = "update_job_pod_name" if job_id else "skip_update_job_pod_name"
    return f"{group_id}.{task_id}" if group_id else task_id


def update_job_pod_name(dag: DAG, job_id: str, process_name: str) -> MySqlOperator:
    """
    Build the task that records the pod name of a job in ``airflow.Pod``.

    :param dag: Airflow DAG
    :param job_id: Airflow Job ID
    :param process_name: name of the current running process
    :return: MySqlOperator that inserts a (job_id, pod_name, task_name) row
        unless a row with the same pod_name already exists
    """
    # Kept as a plain (non-f) string so the Jinja expression reaches Airflow
    # untouched and the XCom lookup happens when the task runs. Writing
    # {xcom_pull(...)} directly inside the f-string would be evaluated by
    # Python while the DAG is parsed, where no such name exists.
    pod_name_template = '{{ ti.xcom_pull(key="pod_name") }}'
    # NOTE(review): job_id and process_name are interpolated straight into the
    # SQL text; confirm they never carry untrusted input.
    return MySqlOperator(
        task_id="update_job_pod_name",
        mysql_conn_id="semantic-search-airflow-sdk",
        autocommit=True,
        sql=[
            f"""
                INSERT INTO airflow.Pod (job_id, pod_name, task_name)
                SELECT * FROM (SELECT '{job_id}', '{pod_name_template}', '{process_name}') AS temp
                WHERE NOT EXISTS (
                    SELECT pod_name FROM airflow.Pod WHERE pod_name = '{pod_name_template}'
                ) LIMIT 1;
            """
        ],
        task_concurrency=1,
        dag=dag,
        trigger_rule="all_done",
    )

def create_k8s_pod_operator_without_volume(dag: DAG,
                                           job_id: int,
                                           process_name: str) -> TaskGroup:
    """
    Create task group for k8 operator without volume

    :param dag: Airflow DAG the group and its tasks are attached to
    :param job_id: job identifier handed to the branch and update tasks
    :param process_name: name of the current running process, stored with the pod name
    :return: TaskGroup in which the branch task fans out to the update and skip tasks
    """
    with TaskGroup(group_id="k8s_pod_operator_without_volume", dag=dag) as eks_without_volume_group:
        emit_pod_name_branch = update_pod_name_branch_operator(dag=dag, job_id=job_id)
        update_pod_name = update_job_pod_name(dag=dag, job_id=job_id, process_name=process_name)
        skip_update_pod_name = skip_update_job_pod_name(dag=dag)
        # NOTE: tasks created inside a TaskGroup are registered as
        # "<group_id>.<task_id>" by default, so the branch callable has to
        # return ids carrying the "k8s_pod_operator_without_volume." prefix.
        emit_pod_name_branch >> [update_pod_name, skip_update_pod_name]
    return eks_without_volume_group

I updated the code based on the comment. I am curious how a TaskGroup works with a branch operator. When I try this, I get the following error: airflow.exceptions.AirflowException: Branch callable must return valid task_ids. Invalid tasks found: {'update_job_pod_name'}


Solution

  • You can use a BranchPythonOperator whose callable receives the value and returns the task_id of the task to run for each condition.

    def choose_job_func(job_id):
        """Return the task_id to follow when job_id is truthy, otherwise None."""
        return "update_pod_name_rds" if job_id else None
    
    
    # Branch task: hands the templated params.job_id value to choose_job_func
    # and follows the task_id it returns.
    # NOTE(review): unless the DAG sets render_template_as_native_obj=True the
    # template presumably renders to a string, so a missing job_id may arrive
    # as the truthy text "None" -- confirm against the DAG definition.
    choose_update_job = BranchPythonOperator(
        op_kwargs={"job_id": "{{ params.job_id }}"},
        python_callable=choose_job_func,
        task_id="choose_update_job",
    )
    

    Or, with the TaskFlow API, it would look like this:

    @task.branch
    def choose_update_job(job_id):
        """Return the task_id to follow when job_id is truthy, otherwise None."""
        return "update_pod_name_rds" if job_id else None
    

    Full DAG example:

    with DAG(
        dag_id="test_dag",
        start_date=datetime(2022, 1, 1),
        schedule_interval=None,
        # Render templates to native Python objects, so that an unset job_id
        # reaches the branch callable as None instead of the text "None".
        render_template_as_native_obj=True,
        params={
            "job_id": Param(default=None, type=["null", "string"])
        },
        tags=["test"],
    ) as dag:

        def update_job_pod_name(job_id: str, process_name: str):
            """
            Build the task that records the pod name of a job in ``airflow.Pod``.

            :param job_id: job identifier (may itself be a Jinja template string)
            :param process_name: name of the current running process
            :return: MySqlOperator with task_id ``update_pod_name_rds``
            """
            # Plain (non-f) string: the Jinja expression must reach Airflow
            # untouched so the XCom lookup happens when the task runs. Writing
            # {xcom_pull(...)} inside the f-string would be evaluated by Python
            # at parse time, where no such name exists.
            pod_name_template = '{{ ti.xcom_pull(key="pod_name") }}'
            return MySqlOperator(
                task_id="update_pod_name_rds",
                mysql_conn_id="semantic-search-airflow-sdk",
                autocommit=True,
                sql=[
                    f"""
                        INSERT INTO airflow.Pod (job_id, pod_name, task_name)
                        SELECT * FROM (SELECT '{job_id}', '{pod_name_template}', '{process_name}') AS temp
                        WHERE NOT EXISTS (
                            SELECT pod_name FROM airflow.Pod WHERE pod_name = '{pod_name_template}'
                        ) LIMIT 1;
                    """
                ],
                task_concurrency=1,
                dag=dag,
                trigger_rule="all_done",
            )

        @task.branch
        def choose_update_job(job_id):
            """Return the task_id to follow: the SQL task when job_id is set, else the no-op."""
            print(job_id)
            if job_id:
                return "update_pod_name_rds"
            return "do_nothing"

        sql_task = update_job_pod_name(
            "{{ params.job_id}}",
            "process_name",
        )
        do_nothing = EmptyOperator(task_id="do_nothing")
        start_dag = EmptyOperator(task_id="start")
        # ONE_SUCCESS: only one of the two branches runs, the other is skipped.
        end_dag = EmptyOperator(task_id="end", trigger_rule=TriggerRule.ONE_SUCCESS)

        (start_dag >> choose_update_job("{{ params.job_id }}") >> [sql_task, do_nothing] >> end_dag)