
Airflow - Stop DAG based on condition (skip remaining tasks after branch)


I am new to Airflow, so I have a question.

I want to run a DAG only if a condition in the first task is satisfied. If the condition is not satisfied, I want to stop the DAG after the first task.

Example:

# first task
def get_number_func(**kwargs):

    number = randint(0, 10)
    print(number)
    
    if number >= 5:
        print('A')
        return 'continue_task'
    else:
        pass  # STOP DAG
        
# second task if number is higher or equal 5
def continue_func(**kwargs):
    print("The number is " + str(number))
    
# first task declaration
start_op = BranchPythonOperator(
    task_id='get_number',
    provide_context=True,
    python_callable=get_number_func,
    op_kwargs={},
    dag=DAG,
)

# second task declaration
continue_op = PythonOperator(
    task_id='continue_task',
    provide_context=True,
    python_callable=continue_func,
    op_kwargs={},
    dag=DAG,
)

start_op  >> continue_op 

The second task should run only if the condition on the number is satisfied. If the condition is not met, the DAG should not run the second task.

How can I do that? I don't want to use XCom, global variables, or a dummy task.

Thanks in advance!


Solution

  • Have you checked out the ShortCircuitOperator? It controls the task flow based on the value its callable returns. If the callable returns a truthy value, the downstream tasks run as usual; otherwise, all downstream tasks are skipped. Try changing your first task to a ShortCircuitOperator and updating get_number_func to return True or False.

    Here is my test using your code:

    from airflow.models import DAG
    from airflow.operators.python import PythonOperator, ShortCircuitOperator
    
    from datetime import datetime
    
    
    default_args = dict(
        start_date=datetime(2021, 4, 26),
        owner="me",
        retries=0,
    )
    
    dag_args = dict(
        dag_id="short_circuit",
        schedule_interval=None,
        default_args=default_args,
        catchup=False,
    )
    
    
    def get_number_func(**kwargs):
        from random import randint
    
        number = randint(0, 10)
        print(number)
    
        if number >= 5:
            print("A")
            return True
        else:
            # Returning False makes the ShortCircuitOperator skip all downstream tasks
            return False
    
    
    def continue_func(**kwargs):
        pass
    
    
    with DAG(**dag_args) as dag:
        # first task declaration
        # (provide_context is omitted: in Airflow 2 the context is passed to **kwargs automatically)
        start_op = ShortCircuitOperator(
            task_id="get_number",
            python_callable=get_number_func,
            op_kwargs={},
        )
    
        # second task declaration
        continue_op = PythonOperator(
            task_id="continue_task",
            python_callable=continue_func,
            op_kwargs={},
        )
    
        start_op >> continue_op
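
If you want to check the condition without running the DAG, one option is to split the decision out of the callable so it can be tested as plain Python. This is only a minimal sketch: `should_continue` and its `threshold` parameter are names I made up for illustration, not part of Airflow or of the code above.

```python
from random import randint


def should_continue(number: int, threshold: int = 5) -> bool:
    """Hypothetical helper: True when downstream tasks should run."""
    return number >= threshold


def get_number_func(**kwargs) -> bool:
    """Callable intended for a ShortCircuitOperator (returns True or False)."""
    number = randint(0, 10)
    print(number)
    # A False return value is what tells the operator to skip downstream tasks.
    return should_continue(number)


if __name__ == "__main__":
    # Quick local check of the decision logic, no scheduler needed.
    for n in (0, 4, 5, 10):
        print(n, should_continue(n))
```

You would still pass `get_number_func` as `python_callable`, exactly as in the DAG above; the only difference is that the `number >= 5` rule now lives in a function you can call directly.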