
Batch predictions with Google AutoML via Python


I'm pretty new to Stack Overflow as well as to the Google Cloud Platform, so apologies if I'm not asking this question in the right format. I'm currently having trouble getting predictions from my model.

I've trained a multilabel AutoML model on the Google Cloud Platform and now I want to use that model to score new data entries. Since the platform's UI only allows one entry at a time, I want to use Python to run batch predictions.

I've stored my data entries in separate .txt files in a Google Cloud Storage bucket and created a .txt file that lists the gs:// references to those files (as the documentation recommends).
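For illustration, the file list described above could be generated with a few lines of Python. This is only a sketch: the helper name `build_file_list`, the entry file names, and the one-URI-per-line layout are assumptions made for the example, not something confirmed by the question or the documentation.

```python
from typing import Iterable


def build_file_list(bucket: str, prefix: str, filenames: Iterable[str]) -> str:
    """Return the text of a file list with one gs:// URI per line (assumed layout)."""
    uris = [f"gs://{bucket}/{prefix}/{name}" for name in filenames]
    return "\n".join(uris) + "\n"


# Hypothetical entry names; the bucket and prefix mirror the placeholders in the question.
file_list = build_file_list("bucket_name", "xxx", ["entry_001.txt", "entry_002.txt"])
print(file_list, end="")
```

The resulting text would then be uploaded to the bucket as file_list.txt and passed as the batch prediction input.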

I've exported a .json file with the credentials for my service account and specified the IDs and paths in my code:

import os

from google.cloud import automl

# load API credentials and specify model / path references
path = 'xxx.json'
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = path

model_name = 'xxx'
model_id = 'TCN1234567890'
project_id = '1234567890'
model_full_id = f"https://eu-automl.googleapis.com/v1/projects/{project_id}/locations/eu/models/{model_id}"
input_uri = f"gs://bucket_name/{model_name}/file_list.txt"
output_uri = f"gs://bucket_name/{model_name}/outputs/"
prediction_client = automl.PredictionServiceClient()

Then I run the following code to get the predictions:

# score batch of file_list
gcs_source = automl.GcsSource(input_uris=[input_uri])

input_config = automl.BatchPredictInputConfig(gcs_source=gcs_source)
gcs_destination = automl.GcsDestination(output_uri_prefix=output_uri)
output_config = automl.BatchPredictOutputConfig(
    gcs_destination=gcs_destination
)

response = prediction_client.batch_predict(
    name=model_full_id,
    input_config=input_config,
    output_config=output_config
)

print("Waiting for operation to complete...")
print(
    f"Batch Prediction results saved to Cloud Storage bucket. {response.result()}"
)

However, I'm getting the following error: InvalidArgument: 400 Request contains an invalid argument.


Does anyone have a hint as to what is causing this issue? Any input would be appreciated. Thanks!


Solution

  • Found the issue!

    I needed to point the client at the 'eu' regional endpoint first:

    from google.api_core.client_options import ClientOptions

    options = ClientOptions(api_endpoint='eu-automl.googleapis.com')
    prediction_client = automl.PredictionServiceClient(client_options=options)
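A minimal sketch of how the endpoint fix might be wrapped so that the location is only written once. The helper names (`api_endpoint_for`, `model_resource_name`, `make_prediction_client`) are hypothetical and not part of google-cloud-automl. The 'eu' host comes from the answer above, and `automl.googleapis.com` is assumed as the fallback because it is the client's default host. Note also that Google's samples pass the model as a resource name of the form projects/.../locations/.../models/... rather than as a full https:// URL, so the sketch builds that form.

```python
# Sketch only: the helper names below are hypothetical, not part of google-cloud-automl.
DEFAULT_ENDPOINT = "automl.googleapis.com"  # assumed fallback: the client's default host
REGIONAL_ENDPOINTS = {"eu": "eu-automl.googleapis.com"}  # 'eu' host taken from the answer


def api_endpoint_for(location):
    """Return the API host for a model location, falling back to the default host."""
    return REGIONAL_ENDPOINTS.get(location, DEFAULT_ENDPOINT)


def model_resource_name(project_id, location, model_id):
    """Build the model resource name (a path, not an https:// URL)."""
    return f"projects/{project_id}/locations/{location}/models/{model_id}"


def make_prediction_client(location):
    """Create a PredictionServiceClient bound to the endpoint for `location`."""
    # Imported lazily so the pure helpers above work without the library installed.
    from google.api_core.client_options import ClientOptions
    from google.cloud import automl

    options = ClientOptions(api_endpoint=api_endpoint_for(location))
    return automl.PredictionServiceClient(client_options=options)


print(api_endpoint_for("eu"))
print(model_resource_name("1234567890", "eu", "TCN1234567890"))
```

With these helpers, the call in the question would become `make_prediction_client("eu").batch_predict(name=model_resource_name(project_id, "eu", model_id), ...)`, which keeps the client endpoint and the location in the model name from drifting apart.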