Tags: java, spring-boot, google-cloud-vertex-ai

Spring Boot - Google Vertex AI authentication


I would like to use Google's Vertex AI model gemini-1.0-pro-vision-001 from a Spring Boot-based Java backend application.

I would like to use a service account key to grant my application access to Vertex AI.

Maven:

<dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>com.google.cloud</groupId>
            <artifactId>libraries-bom</artifactId>
            <version>26.32.0</version>
            <type>pom</type>
            <scope>import</scope>
        </dependency>
    </dependencies>
</dependencyManagement>

<dependency>
    <groupId>com.google.cloud</groupId>
    <artifactId>google-cloud-vertexai</artifactId>
</dependency>

Google credentials:

The credential itself works: I also use it to access a private Cloud Storage bucket. For some reason, however, it does not work when I try to use Vertex AI.

@Lazy
@Bean
GoogleCredentials googleCredentials() throws IOException {
   return GoogleCredentials.fromStream(new ClassPathResource(this.config.getCredential().getPath()).getInputStream());
}

Vertex AI config:

  @Lazy
  @Bean
  VertexAI vertexAI(GoogleCredentials googleCredentials) throws IOException {
      return new VertexAI(this.config.getProjectId(), this.config.getGemini().getLocation(), googleCredentials);
  }

Generative model:

  @Lazy
  @Bean
  GenerativeModel geminiProVision(VertexAI vertexAI) {
    final GenerationConfig generationConfig = GenerationConfig.newBuilder()
        .setMaxOutputTokens(2048)
        .setTemperature(0.4F)
        .setTopK(32)
        .setTopP(1)
        .build();

    return new GenerativeModel("gemini-1.0-pro-vision-001", generationConfig, vertexAI);
  }
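For orientation: the `VertexAI` bean supplies a project ID and a location, and `GenerativeModel` adds the model name. As far as I know, Vertex AI addresses a Google-published model through a regional host of the form `{location}-aiplatform.googleapis.com` and a resource name of the form `projects/{project}/locations/{location}/publishers/google/models/{model}`. The stdlib-only sketch below shows that mapping, which can help when checking that the configured location and model name are what you expect. `VertexModelTarget` is a hypothetical helper, not part of the SDK, and how the SDK builds these strings internally is not shown here.

```java
import java.util.Objects;

/**
 * Hypothetical helper: derives the regional host and the publisher-model
 * resource name from the same values the VertexAI and GenerativeModel beans use.
 */
final class VertexModelTarget {
    private final String projectId;
    private final String location;
    private final String modelName;

    VertexModelTarget(String projectId, String location, String modelName) {
        this.projectId = requireNonBlank(projectId, "projectId");
        this.location = requireNonBlank(location, "location");
        this.modelName = requireNonBlank(modelName, "modelName");
    }

    /** Regional host, e.g. "us-central1-aiplatform.googleapis.com". */
    String endpoint() {
        return location + "-aiplatform.googleapis.com";
    }

    /** Resource name of a Google-published model such as Gemini. */
    String resourceName() {
        return String.format("projects/%s/locations/%s/publishers/google/models/%s",
                projectId, location, modelName);
    }

    // Fail fast on missing configuration instead of sending a malformed request.
    private static String requireNonBlank(String value, String name) {
        Objects.requireNonNull(value, name + " must not be null");
        if (value.isBlank()) {
            throw new IllegalArgumentException(name + " must not be blank");
        }
        return value;
    }
}
```

Validating in the constructor means a blank `location` or project ID in the application config surfaces at startup rather than as a confusing RPC error later.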

Service account permissions:

[Screenshot: service account permissions]

Use case:

  @Lazy
  @Autowired
  private GenerativeModel geminiProVision;

  public String execute(String productName, String businessId) {
    try {
      final List<SafetySetting> safetySettings = Arrays.asList(
          SafetySetting.newBuilder()
              .setCategory(HarmCategory.HARM_CATEGORY_HATE_SPEECH)
              .setThreshold(SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE)
              .build(),
          SafetySetting.newBuilder()
              .setCategory(HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT)
              .setThreshold(SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE)
              .build(),
          SafetySetting.newBuilder()
              .setCategory(HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT)
              .setThreshold(SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE)
              .build(),
          SafetySetting.newBuilder()
              .setCategory(HarmCategory.HARM_CATEGORY_HARASSMENT)
              .setThreshold(SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE)
              .build());

      final List<Content> contents = new ArrayList<>();
      contents.add(Content.newBuilder()
          .setRole("user")
          .addParts(Part.newBuilder()
              .setText(
                  "This is a test question to Google's Vertex AI"))
          .build());

      final ResponseStream<GenerateContentResponse> responseStream = this.geminiProVision.generateContentStream(contents, safetySettings);

      responseStream.stream().forEach(System.out::println);

      return "";
    } catch (final Exception e) {
      log.error("Gemini was not able to respond!", e);
      return null;
    }
  }
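The use case prints each streamed chunk and returns an empty string. If the goal is to return the model's answer, the chunk texts can be concatenated instead, since a streamed response arrives as consecutive fragments of one reply. The stdlib-only sketch below leaves the text extractor as a parameter because how text is pulled out of a `GenerateContentResponse` depends on the SDK version; `StreamedText` is a hypothetical helper.

```java
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/** Hypothetical helper: joins the text of streamed chunks into one answer. */
final class StreamedText {
    private StreamedText() {}

    static <T> String collect(Stream<T> chunks, Function<? super T, String> textOf) {
        // The stream is consumed once; fragments are joined in arrival order with
        // no separator, because each chunk continues the previous one.
        return chunks.map(textOf).collect(Collectors.joining());
    }
}
```

In the use case this could become something like `String answer = StreamedText.collect(responseStream.stream(), extractor);`, where `extractor` is whatever your SDK version offers for reading a response's text (the SDK's `ResponseHandler.getText` is one candidate if it exists in your version and throws no checked exception).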

Google APIs: [Screenshots: Google APIs configuration]

Problem: As described above, the credential works when I use it to access Cloud Storage, but I get the following error when I try to use Vertex AI:

        com.google.api.gax.rpc.UnauthenticatedException: io.grpc.StatusRuntimeException: UNAUTHENTICATED: Request had invalid authentication credentials. Expected OAuth 2 access token, login cookie or other valid authentication credential. See https://developers.google.com/identity/sign-in/web/devconsole-project.
        at com.google.api.gax.rpc.ApiExceptionFactory.createException(ApiExceptionFactory.java:116)
        at com.google.api.gax.rpc.ApiExceptionFactory.createException(ApiExceptionFactory.java:41)
        at com.google.api.gax.grpc.GrpcApiExceptionFactory.create(GrpcApiExceptionFactory.java:86)
        at com.google.api.gax.grpc.GrpcApiExceptionFactory.create(GrpcApiExceptionFactory.java:66)
        at com.google.api.gax.grpc.ExceptionResponseObserver.onErrorImpl(ExceptionResponseObserver.java:82)
        at com.google.api.gax.rpc.StateCheckingResponseObserver.onError(StateCheckingResponseObserver.java:84)
        at com.google.api.gax.grpc.GrpcDirectStreamController$ResponseObserverAdapter.onClose(GrpcDirectStreamController.java:148)
        at io.grpc.PartialForwardingClientCallListener.onClose(PartialForwardingClientCallListener.java:39)
        at io.grpc.ForwardingClientCallListener.onClose(ForwardingClientCallListener.java:23)
        at io.grpc.ForwardingClientCallListener$SimpleForwardingClientCallListener.onClose(ForwardingClientCallListener.java:40)
        at com.google.api.gax.grpc.ChannelPool$ReleasingClientCall$1.onClose(ChannelPool.java:570)
        at io.grpc.internal.DelayedClientCall$DelayedListener$3.run(DelayedClientCall.java:489)
        at io.grpc.internal.DelayedClientCall$DelayedListener.delayOrExecute(DelayedClientCall.java:453)
        at io.grpc.internal.DelayedClientCall$DelayedListener.onClose(DelayedClientCall.java:486)
        at io.grpc.internal.ClientCallImpl.closeObserver(ClientCallImpl.java:574)
        at io.grpc.internal.ClientCallImpl.access$300(ClientCallImpl.java:72)
        at io.grpc.internal.ClientCallImpl$ClientStreamListenerImpl$1StreamClosed.runInternal(ClientCallImpl.java:742)
        at io.grpc.internal.ClientCallImpl$ClientStreamListenerImpl$1StreamClosed.runInContext(ClientCallImpl.java:723)
        at io.grpc.internal.ContextRunnable.run(ContextRunnable.java:37)
        at io.grpc.internal.SerializingExecutor.run(SerializingExecutor.java:133)
        at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
        at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
        at java.base/java.lang.Thread.run(Thread.java:833)
        Suppressed: java.lang.RuntimeException: Asynchronous task failed
            at com.google.api.gax.rpc.ServerStreamIterator.hasNext(ServerStreamIterator.java:105)
            at com.google.cloud.vertexai.generativeai.ResponseStreamIteratorWithHistory.hasNext(ResponseStreamIteratorWithHistory.java:37)
            at java.base/java.util.Iterator.forEachRemaining(Iterator.java:132)
            at java.base/java.util.Spliterators$IteratorSpliterator.forEachRemaining(Spliterators.java:1845)
            at java.base/java.util.stream.ReferencePipeline$Head.forEach(ReferencePipeline.java:762)
            at .....product.bll.usecase.AIProductPriceSuggestionUseCase.execute(AIProductPriceSuggestionUseCase.java:72)
            at .....product.api.ProductController.priceSuggestion(ProductController.java:176)
            at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
            at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77)
            at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
            at java.base/java.lang.reflect.Method.invoke(Method.java:568)
            at org.springframework.web.method.support.InvocableHandlerMethod.doInvoke(InvocableHandlerMethod.java:207)
            at org.springframework.web.method.support.InvocableHandlerMethod.invokeForRequest(InvocableHandlerMethod.java:152)
            at org.springframework.web.servlet.mvc.method.annotation.ServletInvocableHandlerMethod.invokeAndHandle(ServletInvocableHandlerMethod.java:118)
            at org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.invokeHandlerMethod(RequestMappingHandlerAdapter.java:884)
            at org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.handleInternal(RequestMappingHandlerAdapter.java:797)
            at org.springframework.web.servlet.mvc.method.AbstractHandlerMethodAdapter.handle(AbstractHandlerMethodAdapter.java:87)
            at org.springframework.web.servlet.DispatcherServlet.doDispatch(DispatcherServlet.java:1081)
            at org.springframework.web.servlet.DispatcherServlet.doService(DispatcherServlet.java:974)
            at org.springframework.web.servlet.FrameworkServlet.processRequest(FrameworkServlet.java:1011)
            at org.springframework.web.servlet.FrameworkServlet.doPost(FrameworkServlet.java:914)
            at jakarta.servlet.http.HttpServlet.service(HttpServlet.java:590)
            at org.springframework.web.servlet.FrameworkServlet.service(FrameworkServlet.java:885)
            at jakarta.servlet.http.HttpServlet.service(HttpServlet.java:658)
            at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:205)
            at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:149)
            at org.apache.tomcat.websocket.server.WsFilter.doFilter(WsFilter.java:51)
            at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:174)
            at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:149)
            at .....application.api.web.filter.EmployeeActivityFilter.doFilter(EmployeeActivityFilter.java:57)
            at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:174)
            at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:149)
            at .....application.api.web.filter.JwtRoleBasedFilter.filter(JwtRoleBasedFilter.java:44)
            at .....application.api.web.filter.AbstractJwtFilter.doFilter(AbstractJwtFilter.java:98)
            at .....commons.web.filter.DispatcherFilter.doFilter(DispatcherFilter.java:70)
            at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:174)
            at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:149)
            at org.springframework.security.web.FilterChainProxy.lambda$doFilterInternal$3(FilterChainProxy.java:231)
            at org.springframework.security.web.ObservationFilterChainDecorator$FilterObservation$SimpleFilterObservation.lambda$wrap$1(ObservationFilterChainDecorator.java:479)
            at org.springframework.security.web.ObservationFilterChainDecorator$AroundFilterObservation$SimpleAroundFilterObservation.lambda$wrap$1(ObservationFilterChainDecorator.java:340)
            at org.springframework.security.web.ObservationFilterChainDecorator.lambda$wrapSecured$0(ObservationFilterChainDecorator.java:82)
            at org.springframework.security.web.ObservationFilterChainDecorator$VirtualFilterChain.doFilter(ObservationFilterChainDecorator.java:128)
            at org.springframework.security.web.access.intercept.AuthorizationFilter.doFilter(AuthorizationFilter.java:100)
            at org.springframework.security.web.ObservationFilterChainDecorator$ObservationFilter.wrapFilter(ObservationFilterChainDecorator.java:240)
            at org.springframework.security.web.ObservationFilterChainDecorator$ObservationFilter.doFilter(ObservationFilterChainDecorator.java:227)
            at org.springframework.security.web.ObservationFilterChainDecorator$VirtualFilterChain.doFilter(ObservationFilterChainDecorator.java:137)
            at org.springframework.security.web.access.ExceptionTranslationFilter.doFilter(ExceptionTranslationFilter.java:126)
            at org.springframework.security.web.access.ExceptionTranslationFilter.doFilter(ExceptionTranslationFilter.java:120)
            at org.springframework.security.web.ObservationFilterChainDecorator$ObservationFilter.wrapFilter(ObservationFilterChainDecorator.java:240)
            at org.springframework.security.web.ObservationFilterChainDecorator$ObservationFilter.doFilter(ObservationFilterChainDecorator.java:227)
            at org.springframework.security.web.ObservationFilterChainDecorator$VirtualFilterChain.doFilter(ObservationFilterChainDecorator.java:137)
            at org.springframework.security.web.authentication.AnonymousAuthenticationFilter.doFilter(AnonymousAuthenticationFilter.java:100)
            at org.springframework.security.web.ObservationFilterChainDecorator$ObservationFilter.wrapFilter(ObservationFilterChainDecorator.java:240)
            at org.springframework.security.web.ObservationFilterChainDecorator$ObservationFilter.doFilter(ObservationFilterChainDecorator.java:227)
            at org.springframework.security.web.ObservationFilterChainDecorator$VirtualFilterChain.doFilter(ObservationFilterChainDecorator.java:137)
            at org.springframework.security.web.servletapi.SecurityContextHolderAwareRequestFilter.doFilter(SecurityContextHolderAwareRequestFilter.java:179)
            at org.springframework.security.web.ObservationFilterChainDecorator$ObservationFilter.wrapFilter(ObservationFilterChainDecorator.java:240)
            at org.springframework.security.web.ObservationFilterChainDecorator$ObservationFilter.doFilter(ObservationFilterChainDecorator.java:227)
            at org.springframework.security.web.ObservationFilterChainDecorator$VirtualFilterChain.doFilter(ObservationFilterChainDecorator.java:137)
            at org.springframework.security.web.savedrequest.RequestCacheAwareFilter.doFilter(RequestCacheAwareFilter.java:63)
            at org.springframework.security.web.ObservationFilterChainDecorator$ObservationFilter.wrapFilter(ObservationFilterChainDecorator.java:240)
            at org.springframework.security.web.ObservationFilterChainDecorator$ObservationFilter.doFilter(ObservationFilterChainDecorator.java:227)
            at org.springframework.security.web.ObservationFilterChainDecorator$VirtualFilterChain.doFilter(ObservationFilterChainDecorator.java:137)
            at org.springframework.security.web.authentication.logout.LogoutFilter.doFilter(LogoutFilter.java:107)
            at org.springframework.security.web.authentication.logout.LogoutFilter.doFilter(LogoutFilter.java:93)
            at org.springframework.security.web.ObservationFilterChainDecorator$ObservationFilter.wrapFilter(ObservationFilterChainDecorator.java:240)
            at org.springframework.security.web.ObservationFilterChainDecorator$ObservationFilter.doFilter(ObservationFilterChainDecorator.java:227)
            at org.springframework.security.web.ObservationFilterChainDecorator$VirtualFilterChain.doFilter(ObservationFilterChainDecorator.java:137)
            at org.springframework.web.filter.CorsFilter.doFilterInternal(CorsFilter.java:91)
            at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:116)
            at org.springframework.security.web.ObservationFilterChainDecorator$ObservationFilter.wrapFilter(ObservationFilterChainDecorator.java:240)
            at org.springframework.security.web.ObservationFilterChainDecorator$ObservationFilter.doFilter(ObservationFilterChainDecorator.java:227)
            at org.springframework.security.web.ObservationFilterChainDecorator$VirtualFilterChain.doFilter(ObservationFilterChainDecorator.java:137)
            at org.springframework.security.web.header.HeaderWriterFilter.doHeadersAfter(HeaderWriterFilter.java:90)
            at org.springframework.security.web.header.HeaderWriterFilter.doFilterInternal(HeaderWriterFilter.java:75)
            at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:116)
            at org.springframework.security.web.ObservationFilterChainDecorator$ObservationFilter.wrapFilter(ObservationFilterChainDecorator.java:240)
            at org.springframework.security.web.ObservationFilterChainDecorator$ObservationFilter.doFilter(ObservationFilterChainDecorator.java:227)
            at org.springframework.security.web.ObservationFilterChainDecorator$VirtualFilterChain.doFilter(ObservationFilterChainDecorator.java:137)
            at org.springframework.security.web.context.SecurityContextHolderFilter.doFilter(SecurityContextHolderFilter.java:82)
            at org.springframework.security.web.context.SecurityContextHolderFilter.doFilter(SecurityContextHolderFilter.java:69)
            at org.springframework.security.web.ObservationFilterChainDecorator$ObservationFilter.wrapFilter(ObservationFilterChainDecorator.java:240)
            at org.springframework.security.web.ObservationFilterChainDecorator$ObservationFilter.doFilter(ObservationFilterChainDecorator.java:227)
            at org.springframework.security.web.ObservationFilterChainDecorator$VirtualFilterChain.doFilter(ObservationFilterChainDecorator.java:137)
            at org.springframework.security.web.context.request.async.WebAsyncManagerIntegrationFilter.doFilterInternal(WebAsyncManagerIntegrationFilter.java:62)
            at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:116)
            at org.springframework.security.web.ObservationFilterChainDecorator$ObservationFilter.wrapFilter(ObservationFilterChainDecorator.java:240)
            at org.springframework.security.web.ObservationFilterChainDecorator$ObservationFilter.doFilter(ObservationFilterChainDecorator.java:227)
            at org.springframework.security.web.ObservationFilterChainDecorator$VirtualFilterChain.doFilter(ObservationFilterChainDecorator.java:137)
            at org.springframework.security.web.session.DisableEncodeUrlFilter.doFilterInternal(DisableEncodeUrlFilter.java:42)
            at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:116)
            at org.springframework.security.web.ObservationFilterChainDecorator$ObservationFilter.wrapFilter(ObservationFilterChainDecorator.java:240)
            at org.springframework.security.web.ObservationFilterChainDecorator$AroundFilterObservation$SimpleAroundFilterObservation.lambda$wrap$0(ObservationFilterChainDecorator.java:323)
            at org.springframework.security.web.ObservationFilterChainDecorator$ObservationFilter.doFilter(ObservationFilterChainDecorator.java:224)
            at org.springframework.security.web.ObservationFilterChainDecorator$VirtualFilterChain.doFilter(ObservationFilterChainDecorator.java:137)
            at org.springframework.security.web.FilterChainProxy.doFilterInternal(FilterChainProxy.java:233)
            at org.springframework.security.web.FilterChainProxy.doFilter(FilterChainProxy.java:191)
            at org.springframework.web.filter.DelegatingFilterProxy.invokeDelegate(DelegatingFilterProxy.java:352)
            at org.springframework.web.filter.DelegatingFilterProxy.doFilter(DelegatingFilterProxy.java:268)
            at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:174)
            at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:149)
            at org.springframework.web.filter.RequestContextFilter.doFilterInternal(RequestContextFilter.java:100)
            at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:116)
            at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:174)
            at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:149)
            at org.springframework.web.filter.FormContentFilter.doFilterInternal(FormContentFilter.java:93)
            at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:116)
            at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:174)
            at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:149)
            at org.springframework.web.filter.ServerHttpObservationFilter.doFilterInternal(ServerHttpObservationFilter.java:109)
            at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:116)
            at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:174)
            at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:149)
            at org.springframework.web.filter.CharacterEncodingFilter.doFilterInternal(CharacterEncodingFilter.java:201)
            at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:116)
            at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:174)
            at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:149)
            at .....application.api.web.filter.LogRequestRejectedExceptionFilter.doFilter(LogRequestRejectedExceptionFilter.java:27)
            at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:174)
            at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:149)
            at org.apache.catalina.core.StandardWrapperValve.invoke(StandardWrapperValve.java:166)
            at org.apache.catalina.core.StandardContextValve.invoke(StandardContextValve.java:90)
            at org.apache.catalina.authenticator.AuthenticatorBase.invoke(AuthenticatorBase.java:482)
            at org.apache.catalina.core.StandardHostValve.invoke(StandardHostValve.java:115)
            at org.apache.catalina.valves.ErrorReportValve.invoke(ErrorReportValve.java:93)
            at org.apache.catalina.core.StandardEngineValve.invoke(StandardEngineValve.java:74)
            at org.apache.catalina.connector.CoyoteAdapter.service(CoyoteAdapter.java:341)
            at org.apache.coyote.http11.Http11Processor.service(Http11Processor.java:391)
            at org.apache.coyote.AbstractProcessorLight.process(AbstractProcessorLight.java:63)
            at org.apache.coyote.AbstractProtocol$ConnectionHandler.process(AbstractProtocol.java:894)
            at org.apache.tomcat.util.net.NioEndpoint$SocketProcessor.doRun(NioEndpoint.java:1741)
            at org.apache.tomcat.util.net.SocketProcessorBase.run(SocketProcessorBase.java:52)
            at org.apache.tomcat.util.threads.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1191)
            at org.apache.tomcat.util.threads.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:659)
            at org.apache.tomcat.util.threads.TaskThread$WrappingRunnable.run(TaskThread.java:61)
            ... 1 common frames omitted
    Caused by: io.grpc.StatusRuntimeException: UNAUTHENTICATED: Request had invalid authentication credentials. Expected OAuth 2 access token, login cookie or other valid authentication credential. See https://developers.google.com/identity/sign-in/web/devconsole-project.
        at io.grpc.Status.asRuntimeException(Status.java:533)
        ... 17 common frames omitted

To be honest, I have no idea what I'm doing wrong. Could somebody please help me find out?


Solution

  • After a while, I found the problem and its solution. I did not want to delete the question, since it might be useful for others as well.

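    @Lazy
    @Bean
    GoogleCredentials googleCredentials() throws IOException {
      // A key loaded with fromStream carries no OAuth scopes, so request them explicitly.
      // try-with-resources closes the classpath stream once the key has been read.
      try (InputStream keyStream = new ClassPathResource(this.config.getCredential().getPath()).getInputStream()) {
        return GoogleCredentials.fromStream(keyStream)
            .createScoped("https://www.googleapis.com/auth/cloud-platform",
                "https://www.googleapis.com/auth/cloud-platform.read-only");
      }
    }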
    

    So I had to add the above scopes to the credentials, and now everything works fine.

    Thanks for the help I received.
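Why the scope matters: a service account key on its own carries no scopes. To mint an OAuth 2.0 access token from it, google-auth-library sends Google's token endpoint a signed JWT whose `scope` claim lists the requested scopes (the JWT-bearer flow documented for server-to-server applications). Client libraries built on google-cloud-core, such as Cloud Storage, typically call `createScoped` with their own default scopes when `createScopedRequired()` returns true, which is a likely reason the same key worked for storage; the `VertexAI` constructor used in the question apparently passed the unscoped key through as-is. The stdlib-only sketch below builds and signs such an assertion so the role of `scope` is visible. It is an illustration only: the class name, the example email, and the one-hour lifetime are assumptions, and a real application should let `GoogleCredentials` do this. Note also that `cloud-platform` already covers `cloud-platform.read-only`, so the first scope alone should be enough.

```java
import java.nio.charset.StandardCharsets;
import java.security.PrivateKey;
import java.security.Signature;
import java.time.Instant;
import java.util.Base64;
import java.util.List;

/**
 * Illustration only (hypothetical class): the signed JWT assertion of Google's
 * OAuth 2.0 JWT-bearer flow for service accounts. The "scope" claim is what
 * createScoped(...) ultimately controls.
 */
final class ServiceAccountAssertion {
    static final String TOKEN_URI = "https://oauth2.googleapis.com/token";
    private static final Base64.Encoder B64URL = Base64.getUrlEncoder().withoutPadding();

    /** Claim set; values are assumed to need no JSON escaping (true for emails and scope URLs). */
    static String claims(String clientEmail, List<String> scopes, Instant issuedAt) {
        long iat = issuedAt.getEpochSecond();
        return "{\"iss\":\"" + clientEmail + "\","
                + "\"scope\":\"" + String.join(" ", scopes) + "\","  // space-separated scopes
                + "\"aud\":\"" + TOKEN_URI + "\","
                + "\"iat\":" + iat + ","
                + "\"exp\":" + (iat + 3600) + "}";  // assumed lifetime: one hour, the maximum
    }

    /** Returns header.claims.signature, each part base64url without padding, signed RS256. */
    static String sign(String claimsJson, PrivateKey key) throws Exception {
        String header = "{\"alg\":\"RS256\",\"typ\":\"JWT\"}";
        String signingInput = encode(header) + "." + encode(claimsJson);
        Signature rsa = Signature.getInstance("SHA256withRSA");
        rsa.initSign(key);
        rsa.update(signingInput.getBytes(StandardCharsets.US_ASCII));
        return signingInput + "." + B64URL.encodeToString(rsa.sign());
    }

    private static String encode(String json) {
        return B64URL.encodeToString(json.getBytes(StandardCharsets.UTF_8));
    }
}
```

Without a `scope` claim there is nothing telling the token endpoint which APIs the resulting access token may call, which is consistent with the UNAUTHENTICATED error disappearing once `createScoped` was added.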