
Deploying Hugging Face zero-shot classification in SageMaker using the template returns an error: missing positional argument 'candidate_labels'


I'm using the code generated by Hugging Face (Task: Zero-Shot Classification, Configuration: AWS) and running it in SageMaker's JupyterLab:

from sagemaker.huggingface import HuggingFaceModel
import sagemaker

role = sagemaker.get_execution_role()
# Hub Model configuration. https://huggingface.co/models
hub = {
    'HF_MODEL_ID':'facebook/bart-large-mnli',
    'HF_TASK':'zero-shot-classification'
}

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    transformers_version='4.6.1',
    pytorch_version='1.7.1',
    py_version='py36',
    env=hub,
    role=role, 
)

# deploy model to SageMaker Inference
predictor = huggingface_model.deploy(
    initial_instance_count=1, # number of instances
    instance_type='ml.m5.xlarge' # ec2 instance type
)

predictor.predict({
    'inputs': "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!"
})

The following error is returned:

ModelError: An error occurred (ModelError) when calling the InvokeEndpoint operation: Received client error (400) from primary with message "{ "code": 400, "type": "InternalServerException",
"message": "call() missing 1 required positional argument: \u0027candidate_labels\u0027" } ". See ... in account **** for more information.

I also tried passing the labels as a top-level key, like this:

predictor.predict({
    'inputs': "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!",
    'candidate_labels': ['science', 'life']
})

but it still doesn't work. How should I call the endpoint?


Solution

  • The request body schema for a zero-shot classification model is defined in this link. The candidate_labels list must be nested under "parameters", not placed at the top level:

    {
        "inputs": "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!",
        "parameters": {
            "candidate_labels": [
                "refund",
                "legal",
                "faq"
            ]
        }
    }
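
Applied to the code in the question, the request could be built as in the sketch below. The helper name build_zero_shot_payload is hypothetical and exists only for illustration; predictor is assumed to be the object returned by huggingface_model.deploy(...) in the question. The second attempt in the question most likely failed because the container forwards only the entries under "parameters" to the pipeline as keyword arguments, so a top-level 'candidate_labels' key never reaches it; that explanation is an assumption based on the error message, not something confirmed here.

```python
import json


def build_zero_shot_payload(text, candidate_labels):
    """Build a request body with candidate_labels nested under "parameters".

    Hypothetical helper for illustration; it is not part of the SageMaker SDK.
    """
    if not candidate_labels:
        raise ValueError("candidate_labels must contain at least one label")
    return {
        "inputs": text,
        "parameters": {"candidate_labels": list(candidate_labels)},
    }


payload = build_zero_shot_payload(
    "Hi, I recently bought a device from your company but it is not working "
    "as advertised and I would like to get reimbursed!",
    ["refund", "legal", "faq"],
)

# Inspect the body before sending it to the endpoint.
print(json.dumps(payload, indent=4))

# With the endpoint from the question deployed (assumed, not run here):
# result = predictor.predict(payload)
```

Building the payload in one place keeps the nesting consistent across calls, which is the detail that the two attempts in the question got wrong.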