Tags: google-cloud-composer, airflow, gcp-ai-platform-training

Cleaning up past AI Platform model versions with Airflow


I am using Airflow to schedule the training of a model version on Google Cloud AI Platform. I managed to schedule the training of the model and the creation of the version, and then to set this latest version as the default, using this DAG:

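import os
from uuid import uuid4

from airflow import DAG
# Airflow 1.10 contrib import path (assumed from the operator names used here)
from airflow.contrib.operators.mlengine_operator import (
    MLEngineTrainingOperator,
    MLEngineVersionOperator,
)

# PROJECT_ID, MODEL_NAME, version_name, export_uri, etc. are defined elsewhere
with DAG('ml_pipeline', schedule_interval=None, default_args=default_args) as dag:

    uuid = str(uuid4())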
    training_op = MLEngineTrainingOperator(
        task_id='submit_job_for_training',
        project_id=PROJECT_ID,
        job_id='training_{}'.format(uuid),
        # package_uris=TRAINER_BIN,
        package_uris=[os.path.join(TRAINER_BIN)],
        training_python_module=TRAINER_MODULE,
        runtime_version=RUNTIME_VERSION,
        region='us-central1',
        training_args=[
            '--base-dir={}'.format(BASE_DIR)
        ],
        python_version='3.5')

    create_version_op = MLEngineVersionOperator(
          task_id='create_version',
          project_id=PROJECT_ID,
          model_name=MODEL_NAME,
          version={
              'name': version_name,
              'deploymentUri': export_uri,
              'runtimeVersion': RUNTIME_VERSION,
              'pythonVersion': '3.5',
              'framework': 'SCIKIT_LEARN',
          },
          operation='create')

    set_version_default_op = MLEngineVersionOperator(
          task_id='set_version_as_default',
          project_id=PROJECT_ID,
          model_name=MODEL_NAME,
          version={'name': version_name},
          operation='set_default')
    training_op >> create_version_op >> set_version_default_op

I would like to clean up the previous version of the model in this DAG. I think I should use the "list" and "delete" operations of MLEngineVersionOperator, with something like this:

    list_model_versions = MLEngineVersionOperator(
        task_id="list_versions",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
        operation="list",
    )

    delete_other_version = MLEngineVersionOperator(
        task_id="delete_precedent_version",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
        operation="delete",
        version={'name': some_name}
    )

I read about using XCom to pass the result of the list operator to the delete operator, but I could not figure out how to do this.

Any advice or solution on how to proceed would be most appreciated. Thanks!


Solution

  • You can use a templated field to pass the result of a previous operator through XCom. For example:

    delete_other_version = MLEngineVersionOperator(
        task_id="delete_precedent_version",
        project_id="asimov-foundation",
        model_name="IrisPredictor",
        version_name="{{task_instance.xcom_pull(task_ids='my_previous_task')}}",
        operation="delete",
    )
    

    Here the value of version_name is a Jinja template that pulls its value from XCom. The result of the list operator is a list of versions, however, so you need some additional processing before passing on the name of the version to delete.

    Here's an example of a PythonOperator that takes the list from the previous operator and returns the name of the second most recent version deployed.

    def get_version(**context):
        # Get the list of versions from previous operator
        versions = context['task_instance'].xcom_pull(task_ids='list_versions')
    
        # Sort the version list by createTime and obtain the name of the second most recent version
        full_name = sorted(versions, key=lambda x: x['createTime'], reverse=True)[1]['name']
    
        # The name is in format "projects/PROJECT/models/MODEL/versions/VERSION", so we'll use only VERSION
        return full_name.split('/')[-1]
    
    get_version_task = PythonOperator(
        task_id='get_version_task',
        python_callable=get_version,
        # Needed on Airflow 1.x; Airflow 2 passes the context automatically
        provide_context=True,
    )
    

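    In a PythonOperator, xcom_pull is available through the task instance in the context. The value returned by the callable is itself pushed to XCom, which is how the template in the delete operator can pull it.

    The full DAG is: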

    def get_version(**context):
        # Get the list of versions from previous operator
        versions = context['task_instance'].xcom_pull(task_ids='list_versions')
    
        # Sort the version list by createTime and obtain the name of the second most recent version
        full_name = sorted(versions, key=lambda x: x['createTime'], reverse=True)[1]['name']
    
        # The name is in format "projects/PROJECT/models/MODEL/versions/VERSION", so we'll use only VERSION
        return full_name.split('/')[-1]
    
    
    list_model_versions = MLEngineVersionOperator(
        task_id="list_versions",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
        operation="list",
    )
    
    get_version_task = PythonOperator(
        task_id='get_version_task',
        python_callable=get_version,
        provide_context=True,
    )
    
    delete_other_version = MLEngineVersionOperator(
        task_id="delete_precedent_version",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
        version_name="{{task_instance.xcom_pull(task_ids='get_version_task')}}",
        operation="delete",
    )
    
    list_model_versions >> get_version_task >> delete_other_version
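
One caveat with get_version as written: indexing with [1] raises an IndexError when the model has fewer than two versions, for example on the first run of the pipeline. Below is a minimal sketch of a more defensive callable, together with a way to exercise it locally without Airflow. The names pick_previous_version, FakeTaskInstance and sample_versions are hypothetical and introduced only for this illustration. The sample data assumes the list operation returns dicts with "name" and "createTime" keys, as the code above does.

```python
from typing import Any, Dict, List, Optional


def pick_previous_version(versions: List[Dict[str, Any]]) -> Optional[str]:
    """Return the short name of the second most recent version, or None."""
    if len(versions) < 2:
        # Nothing to clean up yet (e.g. first run of the pipeline)
        return None
    # Same ordering as in the answer: newest first, by createTime
    ordered = sorted(versions, key=lambda v: v["createTime"], reverse=True)
    # Full name looks like projects/PROJECT/models/MODEL/versions/VERSION
    return ordered[1]["name"].split("/")[-1]


def get_version(**context) -> Optional[str]:
    # Pull the list pushed by the 'list_versions' task; treat a missing value as empty
    versions = context["task_instance"].xcom_pull(task_ids="list_versions")
    return pick_previous_version(versions or [])


class FakeTaskInstance:
    """Hypothetical stand-in for Airflow's TaskInstance, for local testing only."""

    def __init__(self, xcoms: Dict[str, Any]):
        self._xcoms = xcoms

    def xcom_pull(self, task_ids: str) -> Any:
        return self._xcoms.get(task_ids)


# Made-up sample data, deliberately out of chronological order
sample_versions = [
    {"name": "projects/p/models/m/versions/v1", "createTime": "2020-01-01T00:00:00Z"},
    {"name": "projects/p/models/m/versions/v3", "createTime": "2020-03-01T00:00:00Z"},
    {"name": "projects/p/models/m/versions/v2", "createTime": "2020-02-01T00:00:00Z"},
]

previous = get_version(
    task_instance=FakeTaskInstance({"list_versions": sample_versions})
)
print(previous)  # v2
```

Note that if the callable returns None, the Jinja template in the delete operator would render the string "None", so the delete task should not run in that case. One option is to raise AirflowSkipException from the callable instead of returning None, which marks the task as skipped.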