def skip_update_job_pod_name(dag):
    """
    Build the no-op task taken when the pod name should not be recorded.

    :param dag: Airflow DAG the task is attached to
    :return: DummyOperator whose task_id is "skip_update_job_pod_name"
    """
    skip_task = DummyOperator(dag=dag, task_id="skip_update_job_pod_name")
    return skip_task
def update_pod_name_branch_operator(dag: DAG, job_id: str):
    """
    Build the branch task that decides whether the pod name gets recorded.

    :param dag: Airflow DAG the task is attached to
    :param job_id: Airflow job ID, handed to ``update_pod_name_func`` as a kwarg
    :return: BranchPythonOperator with task_id "update_pod_name"
    """
    branch_kwargs = {"job_id": job_id}
    return BranchPythonOperator(
        task_id="update_pod_name",
        python_callable=update_pod_name_func,
        op_kwargs=branch_kwargs,
        trigger_rule="all_done",
        dag=dag,
    )
def update_pod_name_func(job_id: Optional[str], task=None) -> str:
    """
    Choose which downstream task the pod-name branch should follow.

    :param job_id: Airflow job ID; any falsy value (None, "") selects the skip task.
    :param task: The branch operator itself. Airflow passes it in from the task
        context because ``task`` is named in this signature. It is used to turn the
        chosen label into the fully-qualified task_id: tasks created inside a
        TaskGroup carry the group id as a prefix (for example
        "k8s_pod_operator_without_volume.update_job_pod_name"), and a branch
        callable that returns the bare label fails with
        "Branch callable must return valid task_ids".
    :return: task_id of the task to run next.
    """
    # Labels must match the task_ids given to the downstream operators exactly.
    label = "update_job_pod_name" if job_id else "skip_update_job_pod_name"

    task_group = getattr(task, "task_group", None)
    if task_group is None:
        # Called outside Airflow (or without context): return the bare label.
        return label
    # child_id() prepends the group id when the group prefixes its children, and
    # returns the label unchanged for the root group or prefix_group_id=False.
    return task_group.child_id(label)
def update_job_pod_name(dag: DAG, job_id: str, process_name: str) -> MySqlOperator:
    """
    Build the task that records which pod ran this job.

    A row is inserted into ``airflow.Pod`` only if no row with the same pod name
    exists yet. The pod name is not known when the DAG file is parsed. It is read
    at run time from the XCom key ``pod_name`` through a Jinja template that
    MySqlOperator renders before executing the SQL.

    :param dag: Airflow DAG
    :param job_id: Airflow Job ID
    :param process_name: name of the current running process (stored as task_name)
    :return: MySqlOperator that inserts the (job_id, pod_name, task_name) row
    """
    # Plain (non-f) string so it stays a Jinja expression. XCom values exist only
    # at run time, so they cannot be fetched while the DAG is being parsed.
    pod_name_template = '{{ ti.xcom_pull(key="pod_name") }}'
    # NOTE(review): job_id and process_name are interpolated without escaping; a
    # value containing "'" breaks (or injects into) the SQL. Confirm both are trusted.
    # The literals are aliased because MySQL names unaliased derived-table columns
    # after their values, and two equal values fail with "Duplicate column name".
    return MySqlOperator(
        task_id="update_job_pod_name",
        mysql_conn_id="semantic-search-airflow-sdk",
        autocommit=True,
        sql=[
            f"""
            INSERT INTO airflow.Pod (job_id, pod_name, task_name)
            SELECT * FROM (
                SELECT '{job_id}' AS job_id,
                       '{pod_name_template}' AS pod_name,
                       '{process_name}' AS task_name
            ) AS temp
            WHERE NOT EXISTS (
                SELECT pod_name FROM airflow.Pod WHERE pod_name = '{pod_name_template}'
            ) LIMIT 1;
            """
        ],
        task_concurrency=1,
        dag=dag,
        trigger_rule="all_done",
    )
def create_k8s_pod_operator_without_volume(dag: DAG,
                                           job_id: int,
                                           # NOTE(review): "....varaible" is a placeholder for parameters
                                           # elided in the original post (presumably including
                                           # ``process_name``, which is used below). It is not valid Python.
                                           ....varaible) -> TaskGroup:
    """
    Create task group for k8 operator without volume.

    The group contains a branch task ("update_pod_name") with two downstream
    tasks: "update_job_pod_name" (a MySqlOperator insert of the pod name) and
    "skip_update_job_pod_name" (a no-op).

    Tasks created inside a TaskGroup get the group id prepended to their task_id
    (here "k8s_pod_operator_without_volume.<task_id>") unless the group is built
    with ``prefix_group_id=False``. The branch callable therefore has to return
    the prefixed ids. A bare "update_job_pod_name" is rejected with
    "Branch callable must return valid task_ids".

    :param dag: Airflow DAG the group is attached to
    :param job_id: job ID forwarded to the branch task and the insert task.
        NOTE(review): annotated ``int`` here but ``str`` on the callees; confirm.
    :return: the TaskGroup containing the three tasks
    """
    with TaskGroup(group_id="k8s_pod_operator_without_volume", dag=dag) as eks_without_volume_group:
        emit_pod_name_branch = update_pod_name_branch_operator(dag=dag, job_id=job_id)
        update_pod_name = update_job_pod_name(dag=dag, job_id=job_id, process_name=process_name)
        skip_update_pod_name = skip_update_job_pod_name(dag=dag)
        # Fan out: the branch task decides which of these two tasks runs.
        emit_pod_name_branch >> [update_pod_name, skip_update_pod_name]
    return eks_without_volume_group
I updated the code based on the comment. I'm curious how a TaskGroup works with a branch operator: when I try this, I get `airflow.exceptions.AirflowException: Branch callable must return valid task_ids. Invalid tasks found: {'update_job_pod_name'}`
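You can use a BranchPythonOperator whose callable receives the value and returns the task_id of the task to run for each condition. Note that tasks created inside a TaskGroup have the group id prepended to their task_id (e.g. `k8s_pod_operator_without_volume.update_job_pod_name`). The branch callable must return that full id, which is why the bare `update_job_pod_name` is reported as invalid.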
def choose_job_func(job_id):
    """
    Branch callable: select the RDS update task when a job ID is present.

    :param job_id: job ID; any falsy value (None, "", 0) selects no task
    :return: "update_pod_name_rds", or None when ``job_id`` is falsy
    """
    if not job_id:
        return None
    return "update_pod_name_rds"
# With default (string) template rendering, Jinja renders a None param as the
# text "None". That is truthy, so choose_job_func would always pick
# "update_pod_name_rds". "or ''" makes an unset job_id render as an empty,
# falsy value. It also behaves the same under render_template_as_native_obj=True.
choose_update_job = BranchPythonOperator(
    task_id="choose_update_job",
    python_callable=choose_job_func,
    op_kwargs={"job_id": "{{ params.job_id or '' }}"},
)
Or, with the TaskFlow API, it would look like this:
@task.branch
def choose_update_job(job_id):
    """
    Return "update_pod_name_rds" when a job ID is given, otherwise None.

    Per the Airflow branching docs, a None result follows no branch, so every
    directly-downstream task is skipped.
    """
    return "update_pod_name_rds" if job_id else None
Full DAG example:
with DAG(
    dag_id="test_dag",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    # Render templates as native Python objects, so an unset "{{ params.job_id }}"
    # reaches the branch callable as None (falsy) rather than the string "None".
    render_template_as_native_obj=True,
    params={
        "job_id": Param(default=None, type=["null", "string"])
    },
    tags=["test"],
) as dag:

    def update_job_pod_name(job_id: str, process_name: str):
        """
        Build the task that records which pod ran this job.

        A row is inserted into ``airflow.Pod`` only if no row with the same pod
        name exists yet. The pod name is read at run time from the XCom key
        ``pod_name`` through a Jinja template that MySqlOperator renders.

        :param job_id: job ID (may itself be a Jinja template string)
        :param process_name: value stored in the ``task_name`` column
        :return: MySqlOperator with task_id "update_pod_name_rds"
        """
        # Plain (non-f) string so it stays a Jinja expression. XCom values exist
        # only at run time, not while the DAG file is being parsed.
        pod_name_template = '{{ ti.xcom_pull(key="pod_name") }}'
        # NOTE(review): job_id and process_name are interpolated without escaping.
        # The literals are aliased because MySQL names unaliased derived-table
        # columns after their values, and two equal values fail with
        # "Duplicate column name".
        return MySqlOperator(
            task_id="update_pod_name_rds",
            mysql_conn_id="semantic-search-airflow-sdk",
            autocommit=True,
            sql=[
                f"""
                INSERT INTO airflow.Pod (job_id, pod_name, task_name)
                SELECT * FROM (
                    SELECT '{job_id}' AS job_id,
                           '{pod_name_template}' AS pod_name,
                           '{process_name}' AS task_name
                ) AS temp
                WHERE NOT EXISTS (
                    SELECT pod_name FROM airflow.Pod WHERE pod_name = '{pod_name_template}'
                ) LIMIT 1;
                """
            ],
            task_concurrency=1,
            dag=dag,
            trigger_rule="all_done",
        )

    @task.branch
    def choose_update_job(job_id):
        """Follow "update_pod_name_rds" when a job ID is set, else "do_nothing"."""
        print(job_id)
        if job_id:
            return "update_pod_name_rds"
        return "do_nothing"

    sql_task = update_job_pod_name(
        "{{ params.job_id}}",
        "process_name",
    )
    do_nothing = EmptyOperator(task_id="do_nothing")
    start_dag = EmptyOperator(task_id="start")
    # One branch runs and the other is skipped, so "end" must not require all
    # upstream tasks to succeed.
    end_dag = EmptyOperator(task_id="end", trigger_rule=TriggerRule.ONE_SUCCESS)

    (start_dag >> choose_update_job("{{ params.job_id }}") >> [sql_task, do_nothing] >> end_dag)