Search code examples
spring-security

Efficient Token Refresh with OAuth2 Client Credentials Flow in Spring Boot 3.4 using RestClient


I’m currently using Spring Boot 3.4 and Java 21 and trying to integrate RestClient with OAuth2 client credentials flow. I’ve come across a tricky scenario and would appreciate any guidance:

Problem Description

I have a bearer token generation URL and a protected API endpoint. I use RestClient with client credentials authorization-grant-type, supplying my client ID and secret. It works perfectly for a single session. However, when someone regenerates the token (via a portal), the old token becomes invalid immediately, and I encounter issues.

Current Challenge

My app keeps using the old token to call APIs, which results in 401 Unauthorized errors. To mitigate this, I tried using WebClient with an exchange filter function to detect 401 errors, fetch a new token, and retry the API call. The retry works for the first call but does not persist the new token for subsequent API calls. The app ends up calling the token generation endpoint every time, which is inefficient.

I have tried with the RestClient before and it dosen't work and I have gone back to WebClient Still the issue exist.

Below code belongs to WebClient.


import lombok.extern.slf4j.Slf4j;
import org.json.JSONObject;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.env.Environment;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.security.oauth2.client.AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.InMemoryReactiveOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Mono;

import java.util.Base64;
import java.util.Objects;

@Slf4j
@Configuration
public class Oauth2WebClientConfig {

   private static final String TRACE_ID = "TRACE_ID";
   private final Environment env;

   @Autowired
   public Oauth2WebClientConfig(Environment env) {
       this.env = env;
   }
   // == Oauth2 Configuration ==

   // == Oauth2 Configuration ==
   @Bean
   ReactiveClientRegistrationRepository clientRegistration() {
       ClientRegistration clientRegistration = ClientRegistration
               .withRegistrationId("custom")
               .tokenUri(env.getProperty("token-uri"))
               .clientId(env.getProperty("client-id"))
               .clientSecret(env.getProperty("client-secret"))
               .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
               .scope(env.getProperty("scope"))
               .build();
       return new InMemoryReactiveClientRegistrationRepository(clientRegistration);
   }

   @Bean
   ReactiveOAuth2AuthorizedClientService authorizedClientService() {
       return new InMemoryReactiveOAuth2AuthorizedClientService(clientRegistration());
   }

   // == WebFlux Configuration ==
   @Bean
   WebClient webClient(ReactiveClientRegistrationRepository clientRegistration, ReactiveOAuth2AuthorizedClientService authorizedClientService) {
       ServerOAuth2AuthorizedClientExchangeFilterFunction oauth = new ServerOAuth2AuthorizedClientExchangeFilterFunction(
               new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(clientRegistration, authorizedClientService));
       oauth.setDefaultClientRegistrationId("custom");

       // @formatter:off
       return WebClient.builder()
               .filter(oauth)
               .filters(exchangeFilterFunctions -> {
                   exchangeFilterFunctions.add(renewTokenFilter());
                   exchangeFilterFunctions.add(logRequest());
                   exchangeFilterFunctions.add(logResponse());
               })
               .build();
       // @formatter:on
   }

   // == Renew Token if expired filter ==
   private ExchangeFilterFunction renewTokenFilter() {
       return (request, next) -> next.exchange(request).flatMap(response -> {
           if (response.statusCode().value() == HttpStatus.UNAUTHORIZED.value()) {
               // @formatter:off
               return response.releaseBody()
                       .then(WebClient.create().post()
                               .uri(Objects.requireNonNull(env.getProperty("token-uri")))
                               .header(HttpHeaders.AUTHORIZATION, "Basic " + Base64.getEncoder().encodeToString((env.getProperty("client-id") + ":" + env.getProperty("client-secret")).getBytes()))
                               .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE)
                               .body(BodyInserters.fromFormData(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()))
                               .retrieve()
                               .bodyToMono(String.class))
                       .flatMap(token -> {
                           JSONObject tokenResponse = new JSONObject(token);
                           log.info("new token : {} TRACE_ID : {}", tokenResponse.getString("access_token"), Objects.requireNonNull(request.headers().get(TRACE_ID)).getFirst());
                           ClientRequest newRequest = ClientRequest
                                   .from(request)
                                   .headers(headers -> headers.remove(HttpHeaders.AUTHORIZATION))
                                   .header(HttpHeaders.AUTHORIZATION, "Bearer " + tokenResponse.getString("access_token"))
                                   .build();

                           return next.exchange(newRequest);
                       });
               // @formatter:on
           } else {
               return Mono.just(response);
           }
       });
   }

   // == Log Request ==
   private ExchangeFilterFunction logRequest() {
       return ExchangeFilterFunction.ofRequestProcessor(clientRequest -> {
           StringBuilder sb = new StringBuilder("Request: \n")
                   .append("Method : ")
                   .append(clientRequest.method())
                   .append(" ")
                   .append("Headers : ")
                   .append(clientRequest.headers())
                   .append(" ")
                   .append("URL : ")
                   .append(clientRequest.url());
           clientRequest
                   .headers()
                   .forEach((name, values) -> values.forEach(value -> sb
                           .append("\n")
                           .append(name)
                           .append(":")
                           .append(value)));
           log.info(sb.toString());

           return Mono.just(clientRequest);
       });
   }

   // == Log Response ==
   private ExchangeFilterFunction logResponse() {
       return ExchangeFilterFunction.ofResponseProcessor(clientResponse -> {
           StringBuilder sb = new StringBuilder("Response: \n")
                   .append("StatusCode: ")
                   .append(clientResponse.statusCode().value())
                   .append(" ")
                   .append("Headers : ")
                   .append(clientResponse.headers());
           clientResponse
                   .headers()
                   .asHttpHeaders()
                   .forEach((key, value1) -> value1.forEach(value -> sb
                           .append("\n")
                           .append(key)
                           .append(":")
                           .append(value)));
           log.info(sb.toString());
           return Mono.just(clientResponse);
       });
   }
   // == WebFlux Configuration ==
}

Questions

  1. Is it possible to persist the refreshed token efficiently using either RestClient or WebClient, so the app doesn’t repeatedly fetch the token (I prefer RestClient)?
  2. Are there any Spring best practices or patterns to handle this scenario?

Solution

  • You should configure your RestClient bean with an OAuth2AuthorizationFailureHandler.

    Samples in the Spring doc linked above and in this other answer (the subject of which is proxy configuration, but all RestClient beans in this answer are configured with an OAuth2AuthorizationFailureHandler, including those auto-configured by "my" starter).