I am trying to run a batch prediction job using the Vertex AI Airflow operator CreateBatchPredictionJobOperator. This requires pulling a model id from XCom, which was pushed by a previous custom container training job. However, CreateBatchPredictionJobOperator doesn't seem to render XCom pulls as expected.
I am running Airflow 2.3.0 on my local machine.
My code looks something like this:
batch_job_task = CreateBatchPredictionJobOperator(
    gcp_conn_id="gcp_connection",
    task_id="batch_job_task",
    job_display_name=JOB_DISPLAY_NAME,
    model_name="{{ ti.xcom_pull(key='model_conf')['model_id'] }}",
    predictions_format="bigquery",
    bigquery_source=BIGQUERY_SOURCE,
    region=REGION,
    project_id=PROJECT_ID,
    machine_type="n1-standard-2",
    bigquery_destination_prefix=BIGQUERY_DESTINATION_PREFIX,
)
This results in a value error when the task runs:
ValueError: Resource {{ ti.xcom_pull(key='model_conf')['model_id'] }} is not a valid resource id.
The expected behaviour would be to pull that variable by key and render it as a string.
I can also confirm that I am able to see the model id (and other info) in XCom by navigating there in the UI. I tried the same xcom_pull syntax with a PythonOperator and it works.
def print_xcom_value(value):
    print("VALUE:", value)

print_xcom_value_by_key = PythonOperator(
    task_id="print_xcom_value_by_key",
    python_callable=print_xcom_value,
    op_kwargs={"value": "{{ ti.xcom_pull(key='model_conf')['model_id'] }}"},
    provide_context=True,
)
> [2022-12-15, 13:11:19 UTC] {logging_mixin.py:115} INFO - VALUE: 3673414612827265024
CreateBatchPredictionJobOperator does not accept provide_context as an argument. I assumed it would render XCom pulls by default, since XCom pulls are used with CreateBatchPredictionJobOperator in an example in the Airflow docs (link here).
Is there any way I can provide context to this Vertex AI Operator to pull from the XCom storage?
Is something wrong with the syntax that I am not seeing? Is there anything I am misunderstanding in the docs?
UPDATE:
One thing that confuses me is that model_name is a templated field according to the Airflow docs (link here), but the field is not rendering the XCom template.
Did you set render_template_as_native_obj=True in your DAG definition? What version of apache-airflow-providers-google do you use?
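A quick way to check this locally is to look at the installed provider version and at whether model_name is listed in the operator's template_fields, since Airflow only renders Jinja in the fields named there. The following is a minimal sketch: the helper names installed_version and is_templated are made up for illustration, and the import path for the operator is the one I believe the Google provider uses, so treat it as an assumption.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def is_templated(operator_cls, field):
    """Return True if `field` is listed in the operator's template_fields."""
    return field in getattr(operator_cls, "template_fields", ())


if __name__ == "__main__":
    print("provider version:", installed_version("apache-airflow-providers-google"))
    try:
        # Assumed import path; adjust if your provider version differs.
        from airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job import (
            CreateBatchPredictionJobOperator,
        )
    except ImportError:
        print("Google provider (or Airflow) is not installed in this environment")
    else:
        print(
            "model_name templated:",
            is_templated(CreateBatchPredictionJobOperator, "model_name"),
        )
```

If the second check prints False, the literal "{{ ... }}" string would be passed to Vertex AI unchanged, which would match the ValueError in the question.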
====
From OP: Your answer was a step in the right direction.
The solution was to upgrade apache-airflow-providers-google to the latest version (at the moment, this is 8.6.0). I'm not able to pinpoint exactly where in the changelog this fix is mentioned.
Setting render_template_as_native_obj=True was not useful for this issue, since it rendered the id pulled from XCom as an int, and I found no proper way to convert it to str when passing it into CreateBatchPredictionJobOperator in the model_name arg.
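For anyone wondering why the conversion to str does not stick: as far as I can tell, Jinja's native rendering evaluates the final rendered text as a Python literal, so a string made only of digits comes back as an int even after a string filter. The sketch below imitates that step with ast.literal_eval. It is an illustration under that assumption, not Jinja's actual code, and native_like and the "model-abc" value are made up for the example.

```python
import ast


def native_like(rendered):
    """Imitate native rendering: parse the text as a Python literal if possible."""
    try:
        return ast.literal_eval(rendered)
    except (ValueError, SyntaxError):
        # Not a valid literal, so the text is kept as a plain string.
        return rendered


# A digits-only id is turned back into an int, even though it started as a str.
model_id = native_like(str(3673414612827265024))
print(type(model_id).__name__, model_id)

# A hypothetical non-numeric id is not a literal and stays a str.
other_id = native_like("model-abc")
print(type(other_id).__name__, other_id)
```

With the default render_template_as_native_obj=False the rendered value is always a str, which is what model_name expects.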