
Vertex AI Airflow Operators don't render XCom pulls (specifically CreateBatchPredictionJobOperator)


I am trying to run a batch prediction task using the Vertex AI Airflow operator CreateBatchPredictionJobOperator. This requires pulling a model ID from XCom, which was pushed by a previous custom container training job. However, CreateBatchPredictionJobOperator doesn't seem to render XCom pulls as expected.

I am running Airflow 2.3.0 on my local machine.

My code looks something like this:

batch_job_task = CreateBatchPredictionJobOperator(
    gcp_conn_id="gcp_connection",
    task_id="batch_job_task",
    job_display_name=JOB_DISPLAY_NAME,
    model_name="{{ ti.xcom_pull(key='model_conf')['model_id'] }}",
    predictions_format="bigquery",
    bigquery_source=BIGQUERY_SOURCE,
    region=REGION,
    project_id=PROJECT_ID,
    machine_type="n1-standard-2",
    bigquery_destination_prefix=BIGQUERY_DESTINATION_PREFIX,
)

This results in a value error when the task runs:

ValueError: Resource {{ ti.xcom_pull(key='model_conf')['model_id'] }} is not a valid resource id.

The expected behaviour would be to pull that variable by key and render it as a string.

I can also confirm that I am able to see the model id (and other info) in XCom by navigating there in the UI. I attempted using the same syntax with xcom_pull with a PythonOperator and it works.

def print_xcom_value(value):
    print("VALUE:", value)

print_xcom_value_by_key = PythonOperator(
        task_id="print_xcom_value_by_key", python_callable=print_xcom_value,
        op_kwargs={"value": "{{ ti.xcom_pull(key='model_conf')['model_id'] }}" },
        provide_context=True,
        )

> [2022-12-15, 13:11:19 UTC] {logging_mixin.py:115} INFO - VALUE: 3673414612827265024

CreateBatchPredictionJobOperator does not accept provide_context as an argument. I assumed it would render XCom pulls by default, since XCom pulls are used with CreateBatchPredictionJobOperator in an example in the Airflow docs (link here).

  1. Is there any way I can provide context to this Vertex AI Operator to pull from the XCom storage?

  2. Is something wrong with the syntax that I am not seeing? Am I misunderstanding anything in the docs?

UPDATE: One thing that confuses me is that model_name is a templated field according to the Airflow docs (link here) but the field is not rendering the XCom template.
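For context on the UPDATE: Airflow renders Jinja only in the attributes that the installed operator class lists in its template_fields, so documentation for a newer provider release can list a field as templated while an older installed release passes the string through untouched. The sketch below is a simplified stdlib model of that rule, not Airflow's actual implementation. ToyOperator, render_template_fields and the regex-based placeholder substitution are all hypothetical stand-ins for illustration.

```python
import re


class ToyOperator:
    """Hypothetical stand-in for an operator; NOT Airflow's real base class."""

    # Only attributes named here are rendered (mirrors the template_fields idea).
    template_fields = ("model_name",)

    def __init__(self, model_name, region):
        self.model_name = model_name
        self.region = region


def render_template_fields(op, context):
    """Replace {{ key }} placeholders, but only in attributes listed in template_fields."""
    pattern = re.compile(r"\{\{\s*(\w+)\s*\}\}")
    for name in op.template_fields:
        value = getattr(op, name)
        if isinstance(value, str):
            rendered = pattern.sub(lambda m: str(context[m.group(1)]), value)
            setattr(op, name, rendered)


op = ToyOperator(model_name="{{ model_id }}", region="{{ region }}")
render_template_fields(op, {"model_id": 3673414612827265024, "region": "us-central1"})

print(op.model_name)  # rendered, because model_name is in template_fields
print(op.region)      # left as the literal template text, because region is not listed
```

If an installed operator version does not list model_name among its template fields, the behaviour would match the ValueError above, where the raw template text reaches the API call. That is one plausible explanation, not a confirmed diagnosis.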


Solution

  • Did you set render_template_as_native_obj=True in your DAG definition? What version of apache-airflow-providers-google do you use?

    ====

    From OP: Your answer was a step in the right direction.

    The solution was to upgrade apache-airflow-providers-google to the latest version (at the moment, this is 8.6.0). I'm not able to pinpoint exactly where in the changelog this fix is mentioned.

    Setting render_template_as_native_obj=True was not useful for this issue since it rendered the id pulled from XCom as an int, and I found no proper way to convert it to str when passed into CreateBatchPredictionJobOperator in the model_name arg.
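On why render_template_as_native_obj=True turned the ID into an int: to the best of my understanding, Jinja's native rendering finishes by trying to parse the rendered text as a Python literal and falls back to the plain string if that fails. The sketch below imitates only that last step with the standard library. native_like_render is a hypothetical helper name, and the example values are illustrative.

```python
import ast


def native_like_render(rendered_text):
    """Imitate the final step of native-object rendering (an assumption about
    Jinja's behaviour): parse the text as a Python literal if possible,
    otherwise return the text unchanged."""
    try:
        return ast.literal_eval(rendered_text)
    except (ValueError, SyntaxError):
        return rendered_text


# A purely numeric ID is a valid Python literal, so it comes back as an int.
model_id = native_like_render("3673414612827265024")
print(type(model_id).__name__)

# Text that is not a valid literal stays a string.
label = native_like_render("model-3673414612827265024")
print(type(label).__name__)
```

Under this model, converting the value to a string inside the template would not help, because the rendered text is still all digits and gets parsed again. That is consistent with the OP's experience, though it has not been verified against the Airflow and Jinja versions involved.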