How to parse airflow variables in custom sensor


I am building a custom sensor. To get values such as data_interval_start and data_interval_end, I read them from the task context:

    def poke(self, context):
        # Every value needed from the context has to be pulled out by hand
        data_interval_start = context['data_interval_start']
        data_interval_end = context['data_interval_end']
        query = f"""
        SELECT
            count(*) as count
        FROM `mytable`
        WHERE mydate >= "{data_interval_start}" and mydate < "{data_interval_end}"
        """
    ...

This approach feels clumsy: every additional value I need from the context has to be extracted into its own variable first.

I expect there is a simpler way to have the Jinja template rendered directly, without extra code. For example, my MySQL query string is:

        SELECT 
            count(*) as count 
        FROM `mytable`
        WHERE mydate >= "{{data_interval_start}}" and mydate < "{{data_interval_end}}"

When I pass this string to MySQLToGCSOperator, the template variables are rendered into the SQL at execution time. How can I get the same behaviour in my sensor without building an f-string by hand?


Solution

  • I found the answer: templating.

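    Because the query has to be rendered, I pass it to the constructor as a parameter, store it on the sensor as self.query, and list that attribute name in template_fields. Airflow renders every attribute named in template_fields with Jinja against the task context before the task executes, so poke() can use self.query directly, with no f-string.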

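    from airflow.sensors.base import BaseSensorOperator
    class CustomMySQLSensor(BaseSensorOperator):
        template_fields = ('query',)  # attributes rendered with Jinja before poke() runs
        def __init__(self, mysql_conn_id: str, query: str, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.mysql_conn_id = mysql_conn_id
            self.query = query  # same name as listed in template_fields
        ...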
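To see why the attribute name has to match template_fields, here is a small stand-alone sketch that mimics the mechanism without Airflow installed. It is a simplified model, not Airflow's own code: ToyBaseOperator, render(), run_once(), get_count() and last_sql are hypothetical names, and render() only substitutes {{ name }} placeholders with a regex instead of using real Jinja (no filters or macros). The idea it illustrates matches the answer above: before the sensor pokes, each attribute named in template_fields is looked up by name, rendered against the context, and written back, so poke() only ever sees the rendered self.query.

```python
import re
from datetime import datetime, timezone
from typing import Any, Dict, Tuple

# Toy stand-in for Jinja (assumption): only replaces "{{ name }}" with
# str(context[name]); real Airflow uses a full Jinja environment.
_PLACEHOLDER = re.compile(r"\{\{\s*(\w+)\s*\}\}")


def render(template: str, context: Dict[str, Any]) -> str:
    return _PLACEHOLDER.sub(lambda m: str(context[m.group(1)]), template)


class ToyBaseOperator:
    """Hypothetical, simplified model of how template_fields is handled."""

    template_fields: Tuple[str, ...] = ()

    def render_template_fields(self, context: Dict[str, Any]) -> None:
        # Each field is looked up by attribute name, rendered, and written back.
        for name in self.template_fields:
            setattr(self, name, render(getattr(self, name), context))

    def run_once(self, context: Dict[str, Any]) -> bool:
        # Render first, then poke: poke() only ever sees rendered values.
        self.render_template_fields(context)
        return self.poke(context)

    def poke(self, context: Dict[str, Any]) -> bool:
        raise NotImplementedError


class CustomMySQLSensor(ToyBaseOperator):
    template_fields = ("query",)

    def __init__(self, mysql_conn_id: str, query: str) -> None:
        self.mysql_conn_id = mysql_conn_id
        self.query = query  # same name as in template_fields
        self.last_sql = ""

    def poke(self, context: Dict[str, Any]) -> bool:
        # No f-string needed: self.query was rendered before poke() ran.
        return self.get_count(self.query) > 0

    def get_count(self, sql: str) -> int:
        # Hypothetical stub: a real sensor would run `sql` against
        # self.mysql_conn_id (e.g. through a MySQL hook) and return the count.
        self.last_sql = sql
        return 1


sensor = CustomMySQLSensor(
    mysql_conn_id="mysql_default",
    query='SELECT count(*) as count FROM `mytable` '
    'WHERE mydate >= "{{ data_interval_start }}" and mydate < "{{ data_interval_end }}"',
)
context = {
    "data_interval_start": datetime(2024, 1, 1, tzinfo=timezone.utc),
    "data_interval_end": datetime(2024, 1, 2, tzinfo=timezone.utc),
}
print(sensor.run_once(context))
print(sensor.last_sql)
```

Running it prints True followed by the query with both placeholders filled in. The timestamps here come from str() on a stdlib datetime, so the exact format can differ from what Airflow renders for its own context values.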