Tags: python, machine-learning, nlp, google-cloud-vertex-ai, large-language-model

How to Set Safety Parameters for Text Generation Model in Google Cloud Vertex AI?


I am working on a research project where I need to summarize news articles using the Google PaLM 2 Text Generation Model. For certain news articles in my dataset, I get an empty response whose safety attributes mark the output as blocked. Here is the code I'm using:

from vertexai.language_models import TextGenerationModel
parameters = {  # default values
    'max_output_tokens': 256,
    'temperature': 0.0,
    'top_p': 1.0,
    'top_k': 40,
}
prompt = "..."
model = TextGenerationModel.from_pretrained('text-bison@001')
response = model.predict(
    prompt,
    **parameters,
)

The following is an example prediction:

Prediction(predictions=[{'content': '', 'citationMetadata': None, 'safetyAttributes': {'blocked': True, 'errors': [253.0]}}], deployed_model_id='', model_version_id='', model_resource_name='', explanations=None)

The issue seems to be related to safety parameters preventing the model from generating a summary for certain news articles. I've been trying to find documentation on how to configure these safety parameters using the Python API, but I could not locate the relevant information.
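To see how many articles are affected, the blocked responses can be separated from the usable ones. The following is only a minimal sketch: it assumes each prediction is available as a plain dict shaped like the example above (with a `safetyAttributes` entry holding a `blocked` flag), and the helper names `is_blocked` and `split_blocked` are hypothetical, not part of the Vertex AI SDK.

```python
def is_blocked(prediction: dict) -> bool:
    """Return True if a prediction dict reports that its output was blocked."""
    # 'safetyAttributes' may be missing or None, so fall back to an empty dict.
    safety = prediction.get("safetyAttributes") or {}
    return bool(safety.get("blocked", False))


def split_blocked(predictions):
    """Partition prediction dicts into (usable, blocked) lists."""
    usable, blocked = [], []
    for prediction in predictions:
        (blocked if is_blocked(prediction) else usable).append(prediction)
    return usable, blocked


# The prediction dict copied from the example output above.
sample = {
    "content": "",
    "citationMetadata": None,
    "safetyAttributes": {"blocked": True, "errors": [253.0]},
}
usable, blocked = split_blocked([sample])
print(len(usable), len(blocked))
```

Collecting the blocked items separately makes it possible to inspect which articles trigger the filter before changing any settings.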

Could someone please provide guidance on how to set the safety parameters for the TextGenerationModel? Any help or pointers to documentation would be greatly appreciated. Thank you!


Solution

  • I'm not sure about Vertex AI, but you can set the safety_settings of the PaLM model (through the google-generativeai package) as follows:

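    import google.generativeai as palm
    # safety_types provides the HarmCategory and HarmBlockThreshold enums used below
    from google.generativeai.types import safety_types

    completion = palm.generate_text(
        model=model,    # model name, e.g. 'models/text-bison-001' (see the last snippet)
        prompt=prompt,  # the text to summarize
        # Each entry disables blocking for one harm category
        safety_settings=[
            {
                "category": safety_types.HarmCategory.HARM_CATEGORY_DEROGATORY,
                "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE,
            },
            {
                "category": safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE,
                "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE,
            },
        ],
    )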
    

    You should check out this guide for complete details of the safety categories and how to set the threshold for each one, as there are multiple categories and several threshold levels.

    NOTE: To use the PaLM API through google-generativeai, you first need to install the package:

    pip install -q google-generativeai
    

    and then set an API key which you'll get from here:

    import google.generativeai as palm
    palm.configure(api_key='YOUR_API_KEY')
    

    and then to access the same text-bison-001 model:

    models = [m for m in palm.list_models() if 'generateText' in m.supported_generation_methods]
    model = models[0].name # use this model on the first code snippet
    print(model) # prints 'models/text-bison-001'
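Because each entry in safety_settings is a category/threshold pair, the list can be generated rather than written out by hand when several categories need the same threshold. This is only a sketch under stated assumptions: `build_safety_settings` is a hypothetical helper (not part of google-generativeai), and the strings in the demo are stand-ins so the block runs without the package or an API key. In real use, pass `safety_types.HarmCategory` members and a `safety_types.HarmBlockThreshold` member as in the first snippet.

```python
def build_safety_settings(categories, threshold):
    """Pair every category with the same threshold.

    Returns a list of {"category": ..., "threshold": ...} dicts in the
    shape used by the safety_settings argument in the first snippet.
    """
    return [{"category": category, "threshold": threshold} for category in categories]


# Stand-in strings for illustration only; replace them with the enum members,
# e.g. safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE and
# safety_types.HarmBlockThreshold.BLOCK_NONE.
demo_settings = build_safety_settings(
    ["HARM_CATEGORY_DEROGATORY", "HARM_CATEGORY_VIOLENCE"],
    "BLOCK_NONE",
)
for setting in demo_settings:
    print(setting)
```

The resulting list would then be passed as safety_settings=... to palm.generate_text. Categories left out of the list are not changed by this helper.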