
How to update an existing model in AWS sagemaker >= 2.0


I have an XGBoost model currently in production on AWS sagemaker, making real-time inferences. After a while, I would like to update the model with a newer one trained on more data and keep everything else as is (e.g. same endpoint, same inference procedure, so really no changes aside from the model itself).

The current deployment procedure is the following :

from sagemaker.xgboost.model import XGBoostModel
from sagemaker.xgboost.model import XGBoostPredictor

xgboost_model = XGBoostModel(
    model_data = <S3 url>,
    role = <sagemaker role>,
    entry_point = 'inference.py',
    source_dir = 'src',
    code_location = <S3 url of other dependencies>,
    framework_version='1.5-1',
    name = model_name)

xgboost_model.deploy(
    instance_type='ml.c5.large',
    initial_instance_count=1,
    endpoint_name = model_name)

Now that I updated the model a few weeks later, I would like to re-deploy it. I am aware that the .deploy() method creates an endpoint and an endpoint configuration so it does it all. I cannot simply re-run my script again since I would encounter an error.

In previous versions of sagemaker I could have updated the model with an extra argument passed to the .deploy() method called update_endpoint = True. In sagemaker >=2.0 this is a no-op. Now, in sagemaker >= 2.0, I need to use the predictor object as stated in the documentation. So I try the following :

predictor = XGBoostPredictor(model_name)
predictor.update_endpoint(model_name= model_name)

This does update the endpoint with a new endpoint configuration. However, I do not know what it is actually updating: nothing in the two lines above says to use the new xgboost_model trained on more data. So where do I tell the update to use the more recent model?

Thank you!

Update

I believe that I need to be looking at production variants as stated in their documentation here. However, their whole tutorial is based on the AWS SDK for Python (boto3), whose artifacts are hard to manage when I have different entry points for each model variant (e.g. different inference.py scripts or any other code that I want packaged with the model when it goes to production).


Solution

  • Since I found an answer to my own question I will post it here for those who encounter the same problem.

    I ended up rewriting my deployment script using the boto3 SDK rather than the sagemaker SDK (or a mix of both, as some documentation on the AWS website suggests, and which may even be outdated).

    Here's the whole script. It shows how to create a sagemaker model object, an endpoint configuration and an endpoint to deploy the model for the first time. It also shows how to update the endpoint with a newer model and new model artifacts such as a new inference script (which was the aim of my question), in case you want to bring your own model and update it safely in production on sagemaker through the boto3 API:

    import boto3
    import sagemaker
    import time
    from datetime import datetime
    from sagemaker import image_uris
    from fileManager import uploadFile  # local script with helper functions (tar, upload, download)
    
    # name of zipped model and zipped inference code
    CODE_TAR = 'your_inference_code_and_other_artifacts.tar.gz'
    MODEL_TAR = 'your_saved_xgboost_model.tar.gz'
    
    # sagemaker params
    smClient = boto3.client('sagemaker')
    smRole = <your_sagemaker_role>
    bucket = sagemaker.Session().default_bucket()
    region = boto3.Session().region_name  # region used to look up the inference image
    INSTANCE_TYPE = 'ml.c5.large'  # instance type for the endpoint; adjust to your needs
    
    # deploy algorithm
    class Deployer:
    
        def __init__(self, modelName, deployRetrained=False):
            self.modelName=modelName
            self.deployRetrained = deployRetrained
            self.prefix = <S3_model_path_prefix>
        
        def deploy(self):
            '''
            Main method to create a sagemaker model, create an endpoint configuration and deploy the model. If deployRetrained
            param is set to True, this method will update an already existing endpoint.
            '''
            # define model name and endpoint name to be used for model deployment/update
            model_name = self.modelName + <any_suffix>
            endpoint_config_name = self.modelName + '-%s' %datetime.now().strftime('%Y-%m-%d-%HH%M')
            endpoint_name = self.modelName
            
            # deploy model for the first time
            if not self.deployRetrained:
                print('Deploying for the first time')
    
                # here you should copy and zip the model dependencies that you may have (such as preprocessors, inference code, config code...)
                # mine were zipped into the file called CODE_TAR
    
                # upload model and model artifacts needed for inference to S3
                uploadFile(list_files=[MODEL_TAR, CODE_TAR], prefix = self.prefix)
    
                # create sagemaker model and endpoint configuration
                self.createSagemakerModel(model_name)
                self.createEndpointConfig(endpoint_config_name, model_name)
    
                # deploy model and wait while endpoint is being created
                self.createEndpoint(endpoint_name, endpoint_config_name)
                self.waitWhileCreating(endpoint_name)
            
            # update model
            else:
                print('Updating existing model')
    
                # upload model and model artifacts needed for inference (here the old ones are replaced)
                # make sure to make a backup in S3 if you would like to keep the older models
                # we replace the old ones and keep the same names to avoid having to recreate a sagemaker model with a different name for the update!
                uploadFile(list_files=[MODEL_TAR, CODE_TAR], prefix = self.prefix)
    
                # create a new endpoint config that takes the new model
                self.createEndpointConfig(endpoint_config_name, model_name)
    
                # update endpoint
                self.updateEndpoint(endpoint_name, endpoint_config_name)
    
                # wait while endpoint updates then delete outdated endpoint configs once it is InService
                # (config names start with self.modelName, so that is the string to match on)
                self.waitWhileCreating(endpoint_name)
                self.deleteOutdatedEndpointConfig(self.modelName, endpoint_config_name)
    
        def createSagemakerModel(self, model_name):
            ''' 
            Create a new sagemaker Model object with an xgboost container and an entry point for inference using boto3 API
            '''
            # Retrieve the XGBoost inference image (container)
            docker_container = image_uris.retrieve(region=region, framework='xgboost', version='1.5-1')
    
            # Relative S3 path (key) of the trained model archive
            model_s3_key = f'{self.prefix}/{MODEL_TAR}'
    
            # Combine the bucket name and the relative S3 path to create the S3 model URI
            model_url = f's3://{bucket}/{model_s3_key}'
    
            # S3 path to the necessary inference code
            code_url = f's3://{bucket}/{self.prefix}/{CODE_TAR}'
            
            # Create a sagemaker Model object with all its artifacts
            smClient.create_model(
                ModelName = model_name,
                ExecutionRoleArn = smRole,
                PrimaryContainer = {
                    'Image': docker_container,
                    'ModelDataUrl': model_url,
                    'Environment': {
                        'SAGEMAKER_PROGRAM': 'inference.py', #inference.py is at the root of my zipped CODE_TAR
                        'SAGEMAKER_SUBMIT_DIRECTORY': code_url,
                    }
                }
            )
        
        def createEndpointConfig(self, endpoint_config_name, model_name):
            ''' 
            Create an endpoint configuration (only for boto3 sdk procedure) and set production variants parameters.
            Each retraining procedure will induce a new variant name based on the endpoint configuration name.
            '''
            smClient.create_endpoint_config(
                EndpointConfigName=endpoint_config_name,
                ProductionVariants=[
                    {
                        'VariantName': endpoint_config_name,
                        'ModelName': model_name,
                        'InstanceType': INSTANCE_TYPE,
                        'InitialInstanceCount': 1
                    }
                ]
            )
    
        def createEndpoint(self, endpoint_name, endpoint_config_name):
            '''
            Deploy the model to an endpoint
            '''
            smClient.create_endpoint(
                EndpointName=endpoint_name,
                EndpointConfigName=endpoint_config_name)
        
        def deleteOutdatedEndpointConfig(self, name_check, current_endpoint_config):
            '''
            Automatically detect and delete endpoint configurations that contain a string 'name_check'. This method can be used
            after a retrain procedure to delete all previous endpoint configurations but keep the current one named 'current_endpoint_config'.
            '''
            # collect every endpoint configuration whose name contains name_check
            # (paginated, since a single list call only returns one page of results)
            paginator = smClient.get_paginator('list_endpoint_configs')
            names_list = []
            for page in paginator.paginate(NameContains=name_check):
                for config_dict in page['EndpointConfigs']:
                    names_list.append(config_dict['EndpointConfigName'])
    
            # drop the current endpoint configuration from the list (we do not want to delete this one since it is live)
            names_list = [name for name in names_list if name != current_endpoint_config]
    
            for name in names_list:
                try:
                    smClient.delete_endpoint_config(EndpointConfigName=name)
                    print('Deleted endpoint configuration for %s' %name)
                except smClient.exceptions.ClientError as err:
                    print('INFO : Could not delete endpoint configuration %s: %s' %(name, err))
    
        def updateEndpoint(self, endpoint_name, endpoint_config_name):
            ''' 
            Update existing endpoint with a new retrained model
            '''
            smClient.update_endpoint(
                EndpointName=endpoint_name,
                EndpointConfigName=endpoint_config_name,
                RetainAllVariantProperties=True)
        
        def waitWhileCreating(self, endpoint_name):
            ''' 
            While the endpoint is being created or updated sleep for 60 seconds.
            '''
            # wait while creating or updating endpoint
            status = smClient.describe_endpoint(EndpointName=endpoint_name)['EndpointStatus']
            print('Status: %s' %status)
            while status != 'InService' and status !='Failed':
                time.sleep(60)
                status = smClient.describe_endpoint(EndpointName=endpoint_name)['EndpointStatus']
                print('Status: %s' %status)
            
            # in case of a deployment failure raise an error
            if status == 'Failed':
                raise ValueError('Endpoint failed to deploy')
    
    if __name__=="__main__":
        deployer = Deployer('MyDeployedModel', deployRetrained=True)
        deployer.deploy()
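
One edge case the script does not cover: an endpoint can report InService while still attached to its previous configuration, for example after an update that was rolled back. Deleting "outdated" configurations at that point would remove the one that is actually live. The following is a minimal sketch of a guard, not part of the original script. It relies only on the `EndpointStatus` and `EndpointConfigName` fields returned by `describe_endpoint`; the function name `endpoint_serves_config` and the `FakeSageMakerClient` class are hypothetical, the latter being a stand-in for `boto3.client('sagemaker')` so the logic can be tried without AWS access.

```python
def endpoint_serves_config(sm_client, endpoint_name, expected_config_name):
    """Return True only if the endpoint is InService AND attached to the expected config.

    sm_client is anything exposing describe_endpoint(EndpointName=...), such as
    boto3.client('sagemaker').
    """
    desc = sm_client.describe_endpoint(EndpointName=endpoint_name)
    return (
        desc["EndpointStatus"] == "InService"
        and desc["EndpointConfigName"] == expected_config_name
    )


class FakeSageMakerClient:
    """Hypothetical stand-in for the real client, for a local dry run only."""

    def __init__(self, status, config_name):
        self.status = status
        self.config_name = config_name

    def describe_endpoint(self, EndpointName):
        # Mimics the two response fields the check above relies on.
        return {
            "EndpointName": EndpointName,
            "EndpointStatus": self.status,
            "EndpointConfigName": self.config_name,
        }


# Dry run: the update went through, so cleaning up old configs would be safe.
updated = FakeSageMakerClient("InService", "MyDeployedModel-new")
print(endpoint_serves_config(updated, "MyDeployedModel", "MyDeployedModel-new"))

# Dry run: the endpoint is InService but still on the old config, so skip the cleanup.
rolled_back = FakeSageMakerClient("InService", "MyDeployedModel-old")
print(endpoint_serves_config(rolled_back, "MyDeployedModel", "MyDeployedModel-new"))
```

In the script, such a check would sit between waitWhileCreating and deleteOutdatedEndpointConfig, with smClient passed as sm_client and the cleanup skipped when it returns False.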
    

    Final comments :

    • The sagemaker documentation covers all of this but does not state that create_model can take the equivalent of an 'entry_point' and a 'source_dir' for inference dependencies (e.g. normalization artifacts, inference scripts and so on). It can, through the SAGEMAKER_PROGRAM and SAGEMAKER_SUBMIT_DIRECTORY environment variables in the PrimaryContainer argument, as seen in the script.

    • My fileManager.py script just contains basic functions to make tar files and to upload and download them to and from my S3 paths. To keep the class simple, I have not included them here.

    • The method deleteOutdatedEndpointConfig may seem like overkill. I wrote it this way because I have multiple endpoint configurations to handle and wanted to remove the ones that are not live AND contain the string name_check (I do not know the exact name of each configuration since it has a datetime suffix). Feel free to simplify it or remove it altogether.
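
Since the fileManager helpers are not shown, here is a minimal sketch of what the tar and upload steps could look like. It is an assumption, not the original fileManager.py: the names `make_code_tar` and `upload_files` are hypothetical, and the only AWS call used is the S3 client's `upload_file(Filename, Bucket, Key)`. The point it illustrates is that the files are added at the root of the archive, which is where the script expects inference.py to be.

```python
import os
import tarfile


def make_code_tar(source_dir, tar_path):
    """Pack the contents of source_dir into a .tar.gz, with entries at the archive root.

    Write tar_path outside source_dir so the archive does not include itself.
    """
    with tarfile.open(tar_path, "w:gz") as tar:
        for name in sorted(os.listdir(source_dir)):
            # arcname drops the source_dir prefix, so inference.py ends up at the root
            tar.add(os.path.join(source_dir, name), arcname=name)
    return tar_path


def upload_files(s3_client, bucket, prefix, file_paths):
    """Upload local files to s3://bucket/prefix/<file name> and return their S3 URIs.

    s3_client is anything exposing upload_file(Filename, Bucket, Key), such as
    boto3.client('s3').
    """
    uris = []
    for path in file_paths:
        key = f"{prefix}/{os.path.basename(path)}"
        s3_client.upload_file(path, bucket, key)
        uris.append(f"s3://{bucket}/{key}")
    return uris
```

With these, something like make_code_tar('src', CODE_TAR) followed by upload_files(boto3.client('s3'), bucket, prefix, [MODEL_TAR, CODE_TAR]) would play the role of uploadFile in the script.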

    Hope it helps.