
SageMaker Managed Spot Training with Object Detection algorithm


I'm trying to train an object detection model, starting from an existing model, using the new Managed Spot Training feature. The parameters used when creating my Estimator are as follows:

import sagemaker
from sagemaker.amazon.amazon_estimator import get_image_uri

session = sagemaker.Session()  # one session for both the image lookup and the estimator

od_model = sagemaker.estimator.Estimator(get_image_uri(session.boto_region_name, 'object-detection', repo_version="latest"),
                                         Config['role'],
                                         train_instance_count = 1,
                                         train_instance_type = 'ml.p3.16xlarge',
                                         train_volume_size = 50,
                                         train_max_run = (48 * 60 * 60),     # 48 hours, in seconds
                                         train_use_spot_instances = True,
                                         train_max_wait = (72 * 60 * 60),    # 72 hours; must be >= train_max_run
                                         input_mode = 'File',
                                         checkpoint_s3_uri = Config['train_checkpoint_uri'],
                                         output_path = Config['s3_output_location'],
                                         sagemaker_session = session
                                         )

(The references to Config above are to a configuration data structure I'm using to centralise some parameters.)
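For reference, the SDK translates these estimator arguments into fields of the CreateTrainingJob request, which is why the errors below mention MaxWaitTimeInSeconds and MaxRuntimeInSeconds rather than train_max_wait and train_max_run. The sketch below shows that correspondence as I understand it for SDK v1 argument names; the helper function and the bucket URI are hypothetical and only illustrate the mapping and the hours-to-seconds arithmetic.

```python
# Sketch (assumption): how the spot-related Estimator arguments (SageMaker
# Python SDK v1 names) correspond to fields of the CreateTrainingJob request.
# spot_request_fields is a hypothetical helper, not part of any SDK.
HOURS = 60 * 60  # seconds per hour


def spot_request_fields(max_run_hours, max_wait_hours, checkpoint_s3_uri):
    """Return the CreateTrainingJob fields that spot training touches."""
    return {
        "EnableManagedSpotTraining": True,  # train_use_spot_instances=True
        "StoppingCondition": {
            "MaxRuntimeInSeconds": max_run_hours * HOURS,    # train_max_run
            "MaxWaitTimeInSeconds": max_wait_hours * HOURS,  # train_max_wait
        },
        "CheckpointConfig": {"S3Uri": checkpoint_s3_uri},    # checkpoint_s3_uri
    }


# Hypothetical checkpoint URI, standing in for Config['train_checkpoint_uri'].
fields = spot_request_fields(48, 72, "s3://my-bucket/checkpoints")
print(fields["StoppingCondition"])
# {'MaxRuntimeInSeconds': 172800, 'MaxWaitTimeInSeconds': 259200}
```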

When I run the above, I get the following exception:

botocore.exceptions.ClientError: An error occurred (ValidationException) when calling the CreateTrainingJob operation: MaxWaitTimeInSeconds above 3600 is not supported for the given algorithm.

If I change train_max_wait to 3600 I get this exception instead:

botocore.exceptions.ClientError: An error occurred (ValidationException) when calling the CreateTrainingJob operation: Invalid MaxWaitTimeInSeconds. It must be present and be greater than or equal to MaxRuntimeInSeconds

However, reducing train_max_run to 3600 or less isn't an option for me: I expect this model to take several days to train on a large data set, and a single epoch alone takes more than an hour.
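Put side by side, the two error messages describe a trap. The toy model below encodes the two checks as I read them from the error text and the blog post's 60-minute rule; it is an assumption for illustration, not SageMaker's actual validation code, and the function name is hypothetical. If the backend treats the algorithm as non-checkpointing, then with a run time above one hour every max_wait value fails one rule or the other.

```python
# Toy model of the two checks implied by the ValidationException messages.
# Assumption: inferred from the error text and the blog post; this is NOT
# SageMaker's real validation logic. All values are in seconds.
NO_CHECKPOINT_WAIT_CAP = 3600  # 60-minute cap quoted from the blog post


def spot_timeout_errors(max_run, max_wait, checkpointing_supported):
    """Return the messages a (max_run, max_wait) pair would trigger."""
    errors = []
    if max_wait < max_run:
        errors.append("MaxWaitTimeInSeconds must be >= MaxRuntimeInSeconds")
    if not checkpointing_supported and max_wait > NO_CHECKPOINT_WAIT_CAP:
        errors.append("MaxWaitTimeInSeconds above 3600 is not supported")
    return errors


max_run = 48 * 60 * 60  # 172800 s, as in the Estimator above
for max_wait in (3600, 72 * 60 * 60):
    # Treated as non-checkpointing: each candidate max_wait breaks one rule.
    print(max_wait, spot_timeout_errors(max_run, max_wait, checkpointing_supported=False))

# Treated as checkpointing: the original 72-hour max_wait passes both rules.
print(spot_timeout_errors(max_run, 72 * 60 * 60, checkpointing_supported=True))  # []
```

Under this model, a multi-day job can only be created if the backend recognizes the algorithm as supporting checkpointing.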

The AWS blog post on Managed Spot Training says that MaxWaitTimeInSeconds is limited to 60 minutes in one case:

For built-in algorithms and AWS Marketplace algorithms that don’t use checkpointing, we’re enforcing a maximum training time of 60 minutes (MaxWaitTimeInSeconds parameter).

Earlier, the same blog post says:

Built-in algorithms: computer vision algorithms support checkpointing (Object Detection, Semantic Segmentation, and very soon Image Classification).

So I don't think my algorithm lacks checkpointing support. In fact, that blog post itself uses Object Detection with a maximum run time of 48 hours, so this doesn't look like an algorithm limitation.

As you can see above, I've set up an S3 URL for the checkpoints. The S3 bucket exists and the training container has access to it: it's the same bucket that holds the training data and model outputs, and I had no problems accessing those before turning on spot training.

My boto3, botocore, and sagemaker libraries are all current versions:

boto3 (1.9.239)
botocore (1.12.239)
sagemaker (1.42.3)
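A version list like the one above can be produced with the standard library alone. This is a minimal sketch assuming Python 3.8 or later (for importlib.metadata); the helper name is hypothetical, and packages that are not installed are reported as None.

```python
# Sketch: report installed versions of the AWS/SageMaker client libraries.
# Assumes Python 3.8+; installed_versions is a hypothetical helper.
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("boto3", "botocore", "sagemaker")):
    """Map each package name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions().items():
    print(f"{name} ({ver})")
```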

As far as I can tell from the docs, I've set everything up correctly. My use case is almost exactly the one described in the blog post quoted above, except that I'm using the SageMaker Python SDK instead of the console.

I'd really like to try Managed Spot Training to save some money, as I have a very long training run coming up. But limiting timeouts to an hour isn't going to work for my use case. Any suggestions?

Update: If I comment out the train_use_spot_instances and train_max_wait options and train on regular on-demand instances, the training job is created successfully. If I then use the console to clone that job and enable Spot instances on the clone, I get the same ValidationException.


Solution

I ran my script again today and it worked fine, with no botocore.exceptions.ClientError. Given that the issue affected both the SageMaker Python SDK and the console, I suspect it was a problem with the backend API rather than my client code.

Either way, it's working now.