Search code examples
Tags: airflow, operators, airflow-2.x, airflow-taskflow

Task Mapping over a task group without using decorator


I have implemented a task group that is meant to be reused across multiple DAGs. In one of those DAGs it makes more sense to use it with dynamic task mapping. Here is the full code of my task group:

from airflow.utils.task_group import TaskGroup
from airflow.operators.python import BranchPythonOperator, PythonOperator
from airflow.operators.email import EmailOperator
from airflow.providers.sftp.operators.sftp import SFTPOperator
from airflow.providers.sftp.hooks.sftp import SFTPHook, SSHHook

def DeliveryGroup(group_id: str, file: str, deliver_args: dict, **kwargs) -> TaskGroup:
    """Build a task group that delivers ``file`` by e-mail or by SFTP.

    The group contains a branch task (``destination_selector``) followed by
    an ``email`` task and an ``sftp`` task. The branch follows the task named
    by ``deliver_args['type']``.

    Args:
        group_id: Id of the task group to create.
        file: Path of the file to attach to the e-mail / upload over SFTP.
        deliver_args: Destination settings. Keys read here: ``type``, ``to``,
            ``subject``, ``cc``, ``body``, ``host``, ``username``,
            ``password``, ``port`` (defaults to 22) and ``path``.
        **kwargs: Forwarded unchanged to ``TaskGroup``.

    Returns:
        The populated ``TaskGroup``.
    """
    with TaskGroup(group_id=group_id, **kwargs) as tg:
        # Operators created inside a TaskGroup are registered under
        # tg.child_id(label), i.e. "<group_id>.<label>" unless
        # prefix_group_id is disabled. A branch callable has to return that
        # full id, so map the bare labels to the ids the tasks really get.
        branch_targets = {label: tg.child_id(label) for label in ("email", "sftp")}

        def select_destination() -> str:
            """Return the task id of the branch to follow."""
            destination = f"{deliver_args.get('type')}"
            # Unknown values (e.g. an already fully-qualified task id) are
            # passed through untouched to stay backward-compatible.
            return branch_targets.get(destination, destination)

        # select destination type
        selector_task = BranchPythonOperator(
            task_id='destination_selector',
            python_callable=select_destination,
        )

        email_task = EmailOperator(
            task_id="email",
            to=deliver_args.get('to'),
            subject=deliver_args.get('subject'),
            cc=deliver_args.get('cc'),
            html_content=deliver_args.get('body'),
            files=[file],
        )

        sftp_task = SFTPOperator(
            task_id="sftp",
            # ssh_conn_id='YinzCam-Connection',
            sftp_hook=SFTPHook(
                remote_host=deliver_args.get('host'),
                username=deliver_args.get('username'),
                password=deliver_args.get('password'),
                port=deliver_args.get('port', 22),
            ),
            local_filepath=[file],
            remote_filepath=[deliver_args.get('path')],
        )

        selector_task >> [email_task, sftp_task]

        return tg

Next, I would like to pass a list of dicts, each representing a separate destination, as the mapped (expanded) parameter of this task group.

# Attempt to map the task group over a list of destination dicts.
# NOTE: DeliveryGroup is a plain function, so this call raises
# AttributeError: 'function' object has no attribute 'partial'.
# NOTE(review): the keyword `args` passed to expand() does not match the
# `deliver_args` parameter name of DeliveryGroup - confirm the intended name.
task3 = DeliveryGroup.partial(
                group_id='deliver',
                file = "my_file.csv",
            ).expand(
                args=dag.default_args.get('destinations') # a list of dicts
            )

However, I received this error: AttributeError: 'function' object has no attribute 'partial'. So what is the correct way to write a mapping over a task group without using a decorator?

A guide of syntax, references


Solution

  • The problem:

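    In your case, DeliveryGroup is a plain Python function (not an operator class, not a decorated task group, not a module, etc.). `.partial()` and `.expand()` are provided by Airflow's operator classes and by functions wrapped with the `@task` / `@task_group` decorators; an ordinary function object has neither.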

    So it's something like:

    # Minimal reproduction of the error with an ordinary function.
    def my_func():
        return True
    
    # my_func is a plain function object; function objects have no partial() attribute
    my_func.partial()  # AttributeError: 'function' object has no attribute 'partial'
    

    If I understand correctly what you need, here is an example:

    def create_group(group_id: str, file: str, task_params: list[dict], group_params: dict = None) -> TaskGroup:
        """Build one task group holding a selector/email/sftp trio per destination.

        For every dict in ``task_params`` three tasks are created, suffixed
        with the destination's index: ``destination_selector_<ix>``,
        ``email_<ix>`` and ``sftp_<ix>``. Each selector branches to the task
        named by that destination's ``type`` entry.

        Args:
            group_id: Id of the task group to create.
            file: Path of the file to attach to the e-mail / upload over SFTP.
            task_params: One dict per destination. Keys read here: ``type``,
                ``to``, ``subject``, ``cc``, ``body``, ``host``, ``username``,
                ``password``, ``port`` (defaults to 22) and ``path``.
            group_params: Optional extra keyword arguments for ``TaskGroup``.

        Returns:
            The populated ``TaskGroup``.
        """
        group_params = group_params or {}

        def make_selector(destination: dict, branch_targets: dict):
            """Return a zero-argument branch callable bound to one destination."""
            # A factory gives every selector its own binding. A lambda written
            # directly in the loop would close over the loop variable and make
            # every selector read the *last* destination. The callable takes no
            # parameters on purpose: Airflow may inject context entries (one of
            # which is named ``params``) into callables that declare them.
            def select_destination() -> str:
                chosen = f"{destination.get('type')}"
                # Unknown values (e.g. an already fully-qualified task id) are
                # passed through untouched to stay backward-compatible.
                return branch_targets.get(chosen, chosen)

            return select_destination

        with TaskGroup(group_id=group_id, **group_params) as group:
            # iterate through the task parameters, create operators dynamically and attach to the group
            for ix, params in enumerate(task_params):
                # Tasks inside a TaskGroup are registered under
                # group.child_id(label), and here they also carry the index
                # suffix, so map the bare type labels to the real task ids.
                branch_targets = {
                    label: group.child_id(f'{label}_{ix}') for label in ('email', 'sftp')
                }

                selector_task = BranchPythonOperator(
                    task_id=f'destination_selector_{ix}',
                    python_callable=make_selector(params, branch_targets),
                )

                email_task = EmailOperator(
                    task_id=f'email_{ix}',
                    to=params.get('to'),
                    subject=params.get('subject'),
                    cc=params.get('cc'),
                    html_content=params.get('body'),
                    files=[file],
                )

                sftp_task = SFTPOperator(
                    task_id=f'sftp_{ix}',
                    # ssh_conn_id='YinzCam-Connection',
                    sftp_hook=SFTPHook(
                        remote_host=params.get('host'),
                        username=params.get('username'),
                        password=params.get('password'),
                        port=params.get('port', 22),
                    ),
                    local_filepath=[file],
                    remote_filepath=[params.get('path')],
                )

                selector_task >> [email_task, sftp_task]

            return group
    
    
    # Example usage: one group with a selector/email/sftp trio per destination.
    deliver = create_group(
        group_id='deliver',
        file='my_file.csv',
        group_params={'ui_fgcolor': '#E4BE4D'},
        task_params=[
            # first destination
            {
                'type': 'test',
                'to': ['[email protected]'],
                'cc': ['[email protected]'],
                'body': 'hello world',
                'host': 'host.so.com',
                'username': 'username1',
                'password': 'password1',
            },
            # second destination
            {
                'type': 'test2',
                'to': ['[email protected]'],
                'cc': ['[email protected]'],
                'body': 'hello world2',
                'host': 'host2.so.com',
                'username': 'username2',
                'password': 'password2',
            },
            # ...
        ],
    )