Tags: python, sql, postgresql, airflow, directed-acyclic-graphs

How to run Airflow dag with conn_id in config template by PostgresOperator?


I have an Airflow DAG with a PostgresOperator that executes a SQL query. I want to switch between my test database and my prod database via run config (Trigger DAG w/ config). But postgres_conn_id is not a template field, so PostgresOperator fails with an error that "{{ dag_run.conf.get('CONN_ID_TEST', 'pg_database') }}" is not a connection. I trigger this DAG with the config {"CONN_ID_TEST": "pg_database_test"}.

I tried creating a custom Postgres operator by copying the code from the Airflow GitHub repository and adding template_fields: Sequence[str] = ("postgres_conn_id",) at the top of my CustomPostgresOperator class, but that doesn't work either (same error).

I have two conn_id environment variables:

  • AIRFLOW_CONN_ID_PG_DATABASE (prod)
  • AIRFLOW_CONN_ID_PG_DATABASE_TEST (test)

My script looks like:

import datetime as dt

from airflow import DAG
from airflow.providers.postgres.operators.postgres import PostgresOperator
from airflow.operators.dummy import DummyOperator

DAG_ID = "init_database"
POSTGRES_CONN_ID = "{{ dag_run.conf.get('CONN_ID_TEST', 'pg_database') }}"

with DAG(
    dag_id=DAG_ID,
    description="My dag",
    schedule_interval="@once",
    start_date=dt.datetime(2022, 1, 1),
    catchup=False,
) as dag:

    start = DummyOperator(task_id="start")

    my_task = PostgresOperator(  # or CustomPostgresOperator
        task_id="select",
        sql="SELECT * FROM pets LIMIT 1;",
        postgres_conn_id=POSTGRES_CONN_ID,
        autocommit=True,
    )

    start >> my_task

How can I solve this problem? And if it is not possible, how can I switch my PostgresOperator connection to my dev database without recreating another DAG script?

Thanks, Léo


Solution

  • Subclassing is a solid way to modify template_fields as you wish. Since template_fields is a class attribute, your subclass only really needs to be the following (assuming you're just adding the connection ID to the existing template_fields):

    from airflow.providers.postgres.operators.postgres import PostgresOperator as _PostgresOperator
    
    
    class PostgresOperator(_PostgresOperator):
        template_fields = [*_PostgresOperator.template_fields, "conn_id"]
    

    The above uses Postgres provider version 5.3.1, which uses the Common SQL provider under the hood, so the connection attribute is actually conn_id. (template_fields refers to instance attribute names rather than parameter names.)
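    The class-attribute mechanics can be sketched without Airflow installed. `_Base` below is a hypothetical stand-in for the provider operator, not a real Airflow class; the point is only that the subclass extends the inherited tuple instead of replacing the operator's logic:

    ```python
    # Hypothetical stand-in for the provider operator; "sql" and "conn_id"
    # mimic its instance attributes.
    class _Base:
        template_fields = ("sql",)  # the base class only templates "sql"

        def __init__(self, sql, conn_id):
            self.sql = sql
            self.conn_id = conn_id


    # The subclass extends the class attribute with the extra field name.
    class Templated(_Base):
        template_fields = [*_Base.template_fields, "conn_id"]


    op = Templated(sql="SELECT 1;", conn_id="{{ dag_run.conf['environment'] }}")
    print(Templated.template_fields)  # ['sql', 'conn_id']
    ```

    Because template_fields is looked up on the class at render time, this two-line subclass is all that is needed; no methods have to be overridden.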

    For example, assume the below DAG gets triggered with a run config of {"environment": "dev"}:

    from pendulum import datetime
    
    from airflow.decorators import dag
    from airflow.providers.postgres.operators.postgres import PostgresOperator as _PostgresOperator
    
    
    class PostgresOperator(_PostgresOperator):
        template_fields = [*_PostgresOperator.template_fields, "conn_id"]
    
    
    @dag(start_date=datetime(2023, 1, 1), schedule=None)
    def template_postgres_conn():
        PostgresOperator(task_id="run_sql", sql="SELECT 1;", postgres_conn_id="{{ dag_run.conf['environment'] }}")
    
    
    template_postgres_conn()
    
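    A run config can be supplied from the CLI when triggering the DAG; the dag id below matches the example above:

    ```shell
    $ airflow dags trigger template_postgres_conn --conf '{"environment": "dev"}'
    ```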

    Looking at the task log, the connection ID "dev" is used to execute the SQL.
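    Conceptually, at task runtime Airflow walks each name in template_fields and replaces that attribute with its rendered value. The loop below is a toy illustration of that idea in plain Python (Airflow's real implementation uses Jinja and a much richer context), so the substitution logic here is only a sketch:

    ```python
    # Toy stand-in for Airflow's template rendering; real Airflow uses Jinja.
    class Op:
        template_fields = ("sql", "conn_id")

        def __init__(self, sql, conn_id):
            self.sql = sql
            self.conn_id = conn_id


    def render(op, context):
        # Replace "{{ key }}" markers in each templated attribute with
        # the matching value from the context dict.
        for field in op.template_fields:
            value = getattr(op, field)
            for key, val in context.items():
                value = value.replace("{{ %s }}" % key, val)
            setattr(op, field, value)


    op = Op(sql="SELECT 1;", conn_id="{{ env }}")
    render(op, {"env": "dev"})
    print(op.conn_id)  # dev
    ```

    This is why adding "conn_id" to template_fields is sufficient: once the attribute name is in the list, the rendering pass rewrites it before the hook opens the connection.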