python airflow rollback

Access Airflow task arguments in the on_failure_callback function


I need a rollback operation to run when a certain Airflow task fails. To know what to roll back, I need access to the task's arguments inside the rollback function. The rollback function is passed as the on_failure_callback argument when the task is defined.

Take this as a simplified example:

from airflow.decorators import dag, task
from airflow.utils.dates import days_ago


def rollback(context: dict):
    print("How do I access the 'task_argument' value?")

@task(on_failure_callback=rollback)
def example_task(task_argument: str) -> None:
    assert False
    
@dag(
    schedule_interval=None,
    start_date=days_ago(1),
)
def example_dag() -> None:
    example_task("the task argument's value.")
    
example_dag()

How do I get the value that was passed to example_task from inside the on_failure_callback? I'm sure it's hiding somewhere in the context variable, but I have not been able to find clear documentation on what context contains. It does have a params field, but that does not contain task_argument.


Solution

  • This code snippet worked for me. If you are using the @task decorator, you need to declare the context variable (here, params) in the function's arguments, per the TaskFlow documentation: https://airflow.apache.org/docs/apache-airflow/stable/core-concepts/taskflow.html#taskflow. The callback can then read the value from context['params'].

    from airflow.decorators import dag, task
    from airflow.utils.dates import days_ago
    
    
    def rollback(context: dict):
        # Print the params dict; the value passed to example_task below shows up here.
        print(context.get('params'))
    
    @task(on_failure_callback=rollback)
    def example_task(params: dict) -> None:
        assert False
        
    @dag(
        schedule_interval=None,
        start_date=days_ago(1),
    )
    def example_dag() -> None:
        example_task(params={'mytask_param': "the task argument's value."})
        
    example_dag()
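
If the callback needs to do more than print, it may help to pull the lookup into a small helper that tolerates a missing params entry. The following is only a sketch under assumptions: extract_rollback_value and fake_context are hypothetical names introduced here for illustration, the key mytask_param is taken from the snippet above, and the context is simulated with a plain dict so the example runs without Airflow installed.

```python
from typing import Any, Mapping


def extract_rollback_value(context: Mapping[str, Any], key: str = "mytask_param") -> Any:
    """Return context["params"][key], or None if either level is missing."""
    # Hypothetical helper: treats a missing or None "params" entry as empty.
    params = context.get("params") or {}
    return params.get(key)


def rollback(context: Mapping[str, Any]) -> None:
    """Failure callback sketch: look up the value and report what would be undone."""
    value = extract_rollback_value(context)
    if value is None:
        # Listing the available keys helps when exploring what the context holds.
        print(f"Nothing to roll back; context keys: {sorted(context)}")
        return
    print(f"Rolling back: {value}")


# Stand-in for the context passed to the callback (assumption: not a real Airflow context).
fake_context = {"params": {"mytask_param": "the task argument's value."}}
rollback(fake_context)
```

In a DAG, rollback would be passed as on_failure_callback exactly as in the snippet above; only the body of the callback changes.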