Tags: python, airflow, mwaa

Pull XCom data outside any operator in Airflow


I need to pull data from xcom into a Python variable, which will then be transformed with some regex and passed on. However, I cannot find anywhere how to read data from xcom without using an operator (that is, directly in Python code). I am using MWAA on AWS with Airflow 2.0.2 and playing around with the snippet below.

s3Path = ""

def pull_from_xcom(**context):
    global s3Path
    msg = context['ti'].xcom_pull(task_ids='sqs', key='messages')
    s3Path = msg['Messages'][0]['Body']

SQSRUN = SQSSensor(
    task_id='sqs',
    poke_interval=0,
    timeout=10,
    sqs_queue=SQS_URL,
    aws_conn_id=AWS)

xcomGet = PythonOperator(
    task_id='xcom_pull',
    python_callable=pull_from_xcom,
    provide_context=True,
    depends_on_past=False)

# s3Path transformations
para1 = re.findall(r"(para1=\w+)", s3Path)
para2 = re.findall(r"(para2=\w+)", s3Path)

sparkstep = ...  # Constructing dict using para1 and para2 for spark job submission

# Calling sparkStep
sparkTransform = EmrAddStepsOperator(
    task_id='S3PathTransform',
    job_flow_id=Variable.get("EMR"),
    aws_conn_id=AWS,
    steps=sparkstep,
)
# Further tasks in dag

This does not work, because the PythonOperator only runs when the DAG runs, whereas the transformed s3Path value is consumed by another operator before the DAG runs. I also tried storing the s3Path value in a variable and reading it back, but that does not work either, because the variable has not been created yet when the DAG is uploaded.

I see that ti.xcom_pull(key=messages, task_ids='sqs') can be used to pull data from xcom, but from where should I get ti? Is there any way to get a task instance to work with xcom without using any operator?

Basically the question is how to get the value which SQSRUN sends to xcom. I am not able to find any documentation or online links on how to use the value which is fetched by SQSSensor. Would really appreciate some help.


Solution

  • I see that ti.xcom_pull(key=messages, task_ids='sqs') can be used to pull data from xcom but from where should I get ti?

    ti is passed down in the execution context. Your snippet demonstrates how that is done.
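
To make the role of the context concrete, here is a minimal sketch that exercises the question's callable outside Airflow. FakeTaskInstance and the payload are hypothetical stand-ins invented for this illustration; in a real run, Airflow supplies the actual task instance under the "ti" key.

```python
class FakeTaskInstance:
    """Hypothetical stand-in for Airflow's TaskInstance, for illustration only."""

    def __init__(self, store):
        self._store = store  # maps (task_id, key) -> pushed value

    def xcom_pull(self, task_ids, key):
        return self._store.get((task_ids, key))


def pull_from_xcom(**context):
    # The callable receives the execution context as keyword arguments;
    # the running task instance is available under "ti".
    msg = context["ti"].xcom_pull(task_ids="sqs", key="messages")
    return msg["Messages"][0]["Body"]


# Made-up payload, shaped like the one indexed in the question.
fake_ti = FakeTaskInstance(
    {("sqs", "messages"): {"Messages": [{"Body": "s3://bucket/para1=a/para2=b"}]}}
)
body = pull_from_xcom(ti=fake_ti)
print(body)
```

The same pattern is handy for unit-testing a callable without a scheduler.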

    Is there any way to get task instance to work with xcom without using any operator?

    Yes, you can fetch the xcom by querying the metadata database in the same way Airflow does.

    from airflow.models.xcom import XCom
    from airflow.utils.session import provide_session

    @provide_session
    def get_sqs_messages(session=None):
        # No limit is passed here: SQLAlchemy refuses order_by() on a query
        # that already has LIMIT applied; .first() below fetches a single row.
        query = XCom.get_many(
            key="messages",
            dag_ids="dag-id",
            task_ids="sqs",
            session=session,
        )
        # Ensure the most recent value is retrieved.
        query = query.order_by(XCom.execution_date.desc())
        xcom = query.with_entities(XCom.value).first()

        if xcom:
            return XCom.deserialize_value(xcom)
    

    In your snippet, you set a global s3Path in your DAG module and override its value inside an operator's callable. However, EmrAddStepsOperator is constructed when the module is parsed, so it receives the initial value bound to s3Path, not the value assigned later at run time.
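
The timing problem can be reproduced without Airflow. The sketch below is a simplified simulation: FakeOperator and update_path are hypothetical names, and in a real deployment the callable also runs in a separate process from the one that parsed the DAG file, which makes the global even less useful.

```python
s3_path = ""  # module-level value, as in the question


class FakeOperator:
    """Hypothetical stand-in for an operator; it only stores its argument."""

    def __init__(self, steps):
        self.steps = steps


def update_path():
    # Plays the role of the PythonOperator callable, which runs much later.
    global s3_path
    s3_path = "s3://bucket/para1=a/para2=b"


# "Parse time": the operator is built from the current, empty value.
op = FakeOperator(steps=[s3_path])

# "Run time": the callable rebinds the global, but the operator keeps
# the object it was given when it was constructed.
update_path()

print(op.steps)  # still holds the empty string
print(s3_path)   # the global itself has changed
```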

    There is a better way, given that your objective is to derive the steps value for EmrAddStepsOperator from an xcom value.

    The steps keyword argument of the EmrAddStepsOperator constructor is a templated field. This means you can pass a Jinja2 template string as its value, and it is rendered at run time, just before the task instance executes.

    sparkstep can be declared as:

    sparkstep = "{{ sparkstep_from_messages(ti.xcom_pull(task_ids='sqs', key='messages')) }}"
    
    sparkTransform = EmrAddStepsOperator(
                task_id='S3PathTransform',
                job_flow_id=Variable.get("EMR"),
                aws_conn_id=AWS,
                steps=sparkstep,
            )
    

    Here, the value pulled from xcom is passed to a function named sparkstep_from_messages, defined as follows.

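    import re

    def sparkstep_from_messages(messages):
        # Extract the S3 path from the first message, as in the question's snippet
        s3Path = messages['Messages'][0]['Body']

        # s3Path transformations
        para1 = re.findall(r"(para1=\w+)", s3Path)
        para2 = re.findall(r"(para2=\w+)", s3Path)

        sparkstep = ...  # Constructing dict using para1 and para2 for spark job submission
        return sparkstep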
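
As a quick check of what the two patterns capture, here is a small standalone sketch. The sample path is hypothetical, since the question does not show the real message body, and extract_params is a helper name invented for this illustration.

```python
import re


def extract_params(s3_path):
    # Same patterns as in the answer: "\w+" stops at the first character
    # that is not a letter, digit, or underscore (here, the "/").
    para1 = re.findall(r"(para1=\w+)", s3_path)
    para2 = re.findall(r"(para2=\w+)", s3_path)
    return para1, para2


# Hypothetical message body; the real layout is not given in the question.
sample = "s3://my-bucket/input/para1=alpha/para2=beta_2/data.csv"
para1, para2 = extract_params(sample)
print(para1)
print(para2)
```

Note that re.findall returns a list, which is empty when the parameter is absent, so the code that builds the step should handle that case.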
    

    You must register this function in user_defined_macros when initializing your DAG so that it is available in the template context.

    user_defined_macros = dict(
        sparkstep_from_messages=sparkstep_from_messages
    )
    
    dag = DAG(dag_id="sample-dag", user_defined_macros=user_defined_macros)