Search code examples
mysqljinja2airflowairflow-api

Using dagrun.conf on custom operator


I am using airflow 2 stable rest API to trigger a dag. I have created a custom DAG that takes input from MySQL (2 tables) join on key.

and in the body of API, I have to send some parameters like below and these parameters will decide which 2 tables to join.

conf": {"database_1":"test","table_1":"student","key_1":"id","database_2": "test","table_2": "college","key_2": "student_id"},

Below is the Custom Operator implementation.

from typing import Dict, Iterable, Mapping, Optional, Union

from airflow.models.baseoperator import BaseOperator
from airflow.providers.mysql.hooks.mysql import MySqlHook
from airflow.providers.mysql.operators.mysql import MySqlOperator
from airflow.utils.decorators import apply_defaults

class MySqlJoinOperator(BaseOperator):

    @apply_defaults
    def __init__(
            self,
            *,
            mysql_conn_id: str = 'mysql_default',
            parameters: Optional[Union[Mapping, Iterable]] = None,
            autocommit: bool = False,
            database_1: str = None,
            table_1: str = None,
            key_1: str = None,
            database_2: str = None,
            table_2: str = None,
            key_2: str = None,
            how: str = 'inner',
            **kwargs,) -> None:
            super().__init__(**kwargs)
            self.mysql_conn_id = mysql_conn_id
            self.autocommit = autocommit
            self.parameters = parameters
            self.database_1 = database_1
            self.table_1 = table_1
            self.key_1 = key_1
            self.database_2 = database_2
            self.table_2 = table_2
            self.key_2= key_2
            self.how = how

    def execute(self, context: Dict) -> None:
        self.log.info('Joining  {}.{} and {}.{}'.format(self.database_1,self.table_1,self.database_2,self.table_2))
        # hook_1 = MySqlHook(mysql_conn_id=self.mysql_conn_id, schema=self.database_1)
        # hook_2 = MySqlHook(mysql_conn_id=self.mysql_conn_id, schema=self.database_2)
        # sql_1 = "select * from {}".format(self.table_1)
        # sql_2 = "select * from {}".format(self.table_2)
        # dataframe_1 = hook_1.get_pandas_df(sql_1)
        # dataframe_2 = hook_2.get_pandas_df(sql_2)
        # resultant_dataframe = dataframe_1.join(dataframe_2,how=self.how)
        hook = MySqlHook(mysql_conn_id = self.mysql_conn_id)
        sql = "select * from {}.{} as table_1 join {}.{} as table_2 on table_1.{} = table_2.{}".format(self.database_1, self.table_1, self.database_2, self.table_2, self.key_1, self.key_2)
        resultant_dataframe = hook.get_records(sql)
        print(resultant_dataframe)
        return resultant_dataframe

and it will be used like this

from airflow import DAG
from airflow. utils.dates import days_ago

from mysql_join import MySqlJoinOperator

dag = DAG(
    dag_id='test_mysql_join',
    schedule_interval=None,
    start_date=days_ago(10),
    tags=['test mysql join'],
)

test_mysql_operator = MySqlJoinOperator(
    task_id='join_test',
    mysql_conn_id = "root_MYSQL_connection",
    database_1= "{{ dag_run.conf['database_1'] }}",
    table_1="{{ dag_run.conf['table_1'] }}",
    key_1="{{ dag_run.conf['key_1'] }}",
    database_2= "{{ dag_run.conf['database_2'] }}",
    table_2 = "{{ dag_run.conf['table_2'] }}",
    key_2 = "{{ dag_run.conf['key_2'] }}",
    dag=dag)

but the jinja template not working here. Can anyone help me with how can I achieve this?


Solution

  • Check the airflow documents on template fields in custom operators here. I believe you just need to add those fields to template_fields.