Tags: airflow, airflow-2.x, airflow-taskflow

Chaining reusable decorated tasks (TaskFlow API) in Airflow


I am currently experimenting with reusable Airflow tasks.

I am basing my tests on the DAGs provided in the Airflow documentation:

from datetime import datetime

from airflow.decorators import dag, task


@task
def add_task(x, y):
    print(f"Task args: x={x}, y={y}")
    return x + y

@dag(start_date=datetime(2022, 1, 1), schedule=None, catchup=False)
def mydag():
    start = add_task.override(task_id="start_0")(1, 2)
    # Every iteration puts start_0 upstream of a new task, so start_1..start_3 fan out
    for x, y, task_id_str in zip([1, 3, 5], [2, 4, 6], ["start_1", "start_2", "start_3"], strict=True):
        start >> add_task.override(task_id=task_id_str)(x, y)

mydag()

Unfortunately, with the current setup, tasks start_1 to start_3 all depend directly on start_0, so they run in parallel.

Is there a Pythonic way to chain these tasks one after another, e.g. t1 >> t2 >> t3? I can't find any examples in the documentation.

Note: I understand that I could use dynamic task mapping, but I want to see whether this can be done without it.

Resources: https://docs.astronomer.io/learn/managing-dependencies

https://airflow.apache.org/docs/apache-airflow/stable/tutorial/taskflow.html#reusing-a-decorated-task

I expect an end result equivalent to:

t0.set_downstream(t1)
t1.set_downstream(t2)
t2.set_downstream(t3)
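The difference between the original loop and the desired result comes down to which task sits on the left of >> on each iteration. The sketch below uses a tiny stand-in class, ToyTask (hypothetical, not Airflow's BaseOperator or XComArg), that only records edges, so it runs without Airflow installed. Linking every new task to start produces a fan-out, while remembering the previously created task produces the sequential start_0 >> start_1 >> start_2 >> start_3 line.

```python
# Toy stand-in for an Airflow task: it only records dependency edges.
# Hypothetical model for illustration, not Airflow's API.
class ToyTask:
    def __init__(self, task_id, edges):
        self.task_id = task_id
        self.edges = edges  # shared list of (upstream_id, downstream_id) pairs

    def set_downstream(self, other):
        self.edges.append((self.task_id, other.task_id))

    def __rshift__(self, other):
        # `a >> b` sets b downstream of a and returns b,
        # so `a >> b >> c` links a to b and then b to c.
        self.set_downstream(other)
        return other


def fan_out(task_ids):
    """Mirrors the question's loop: every task hangs off the first one."""
    edges = []
    start = ToyTask(task_ids[0], edges)
    for task_id in task_ids[1:]:
        start >> ToyTask(task_id, edges)
    return edges


def sequential(task_ids):
    """Keeps a reference to the previous task so each new one follows it."""
    edges = []
    previous = ToyTask(task_ids[0], edges)
    for task_id in task_ids[1:]:
        current = ToyTask(task_id, edges)
        previous >> current
        previous = current
    return edges


ids = ["start_0", "start_1", "start_2", "start_3"]
print(fan_out(ids))
# [('start_0', 'start_1'), ('start_0', 'start_2'), ('start_0', 'start_3')]
print(sequential(ids))
# [('start_0', 'start_1'), ('start_1', 'start_2'), ('start_2', 'start_3')]
```

The same "previous task" idea should carry over to the real DAG: inside mydag(), store the object returned by add_task.override(...)(x, y) in a variable such as previous and write previous >> current on each iteration, since those returned objects already support >> as in the original loop.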



Solution

  • You can use the chain() function described in the Astronomer Learn doc you referenced. The trick is to collect the "start_*" tasks in a list and then pass the unpacked list to chain().

    from datetime import datetime
    
    from airflow.decorators import dag, task
    from airflow.models.baseoperator import chain
    
    
    @task
    def add_task(x, y):
        print(f"Task args: x={x}, y={y}")
        return x + y
    
    @dag(start_date=datetime(2022, 1, 1), schedule=None, catchup=False)
    def mydag():
        start = add_task.override(task_id="start_0")(1, 2)
    
        # Collect the reusable tasks instead of linking each one to start directly
        zip_tasks = []
        for x, y, task_id_str in zip([1, 3, 5], [2, 4, 6], ["start_1", "start_2", "start_3"], strict=True):
            zip_tasks.append(add_task.override(task_id=task_id_str)(x, y))
    
        # Link them in sequence: start_0 >> start_1 >> start_2 >> start_3
        chain(start, *zip_tasks)
    
    mydag()
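For a flat sequence of single tasks, chain(start, *zip_tasks) behaves like start >> start_1 >> start_2 >> start_3: each task is set upstream of the next one. The sketch below reimplements only that pairwise case with a hypothetical toy_chain helper over a minimal stand-in class; Airflow's real chain() also accepts lists of tasks as arguments and applies extra rules to them, which this sketch deliberately ignores.

```python
# Minimal stand-in for a task; hypothetical, for illustration only.
class ToyTask:
    def __init__(self, task_id):
        self.task_id = task_id
        self.downstream = []

    def set_downstream(self, other):
        self.downstream.append(other)


def toy_chain(*tasks):
    """Link each task to the next one: t0 -> t1 -> t2 -> ...

    Covers only the flat case used in the answer (single tasks, no lists).
    """
    for upstream, downstream in zip(tasks, tasks[1:]):
        upstream.set_downstream(downstream)


start = ToyTask("start_0")
# Build the remaining tasks in a list, as the answer does with zip_tasks.
zip_tasks = [ToyTask(f"start_{i}") for i in range(1, 4)]
toy_chain(start, *zip_tasks)

for task in [start, *zip_tasks]:
    print(task.task_id, "->", [t.task_id for t in task.downstream])
# start_0 -> ['start_1']
# start_1 -> ['start_2']
# start_2 -> ['start_3']
# start_3 -> []
```

The unpacking with * matters: in Airflow, chain(start, zip_tasks) would pass the whole list as one argument, which sets start_0 upstream of every task in the list and reproduces the original parallel layout.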
    
