Search code examples
Tags: spring, spring-boot, oauth-2.0, spring-webflux

How to modify the body of the Spring Boot reactive OAuth2 authorization request


When I use the WebClient defined below, an authorization request is created and sent.

How can I modify the authorization POST request that Spring Boot sends to obtain the bearer token? I need to add some fields to the request body.

Client and Provider Configurations

spring.security.oauth2.client.registration.d365.authorization-grant-type=client_credentials
spring.security.oauth2.client.registration.d365.client-id=my-client-id
spring.security.oauth2.client.registration.d365.client-secret=my-secret
spring.security.oauth2.client.provider.d365.token-uri=http://localhost:8085/oauth/token

WebClient configuration

/**
 * Creates the application's {@link WebClient}, with Spring Security's reactive OAuth2
 * exchange filter installed so outgoing calls carry an OAuth2 authorization.
 *
 * <p>"d365" is configured as the default client registration. Requests that do not name a
 * registration themselves use it.
 *
 * <p>NOTE(review): the token request for "d365" is issued by Spring Security itself, not by
 * this WebClient. Changing its body most likely means configuring the client-credentials
 * token response client. Confirm the API against the Spring Security version in use.
 *
 * @param clientRegistrations repository the filter uses to resolve client registrations
 * @return a WebClient whose requests pass through the OAuth2 exchange filter
 */
@Bean
WebClient webClient(ReactiveClientRegistrationRepository clientRegistrations) {
    UnAuthenticatedServerOAuth2AuthorizedClientRepository authorizedClients =
        new UnAuthenticatedServerOAuth2AuthorizedClientRepository();
    ServerOAuth2AuthorizedClientExchangeFilterFunction oauth2Filter =
        new ServerOAuth2AuthorizedClientExchangeFilterFunction(clientRegistrations, authorizedClients);
    oauth2Filter.setDefaultClientRegistrationId("d365");

    return WebClient.builder().filter(oauth2Filter).build();
}

Solution

  • You can read, modify, or otherwise manipulate the request, the response, and their headers with the help of `ServerHttpRequestDecorator` and `ServerHttpResponseDecorator`; see below.

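    Note: I implemented `GatewayFilter` because this logic lives at the gateway level. If you want to modify requests at the microservice level instead, implement `WebFilter` (its `filter` method takes a `WebFilterChain`). Also note that these decorators only see exchanges handled by the server or gateway itself. They do not intercept the token request that the WebClient's OAuth2 filter sends on its own, unless that token endpoint is itself routed through this gateway.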

    import lombok.extern.log4j.Log4j2;
    import org.apache.commons.io.IOUtils;
    import org.reactivestreams.Publisher;
    import org.springframework.cloud.gateway.filter.GatewayFilter;
    import org.springframework.cloud.gateway.filter.GatewayFilterChain;
    import org.springframework.context.annotation.Configuration;
    import org.springframework.core.Ordered;
    import org.springframework.core.io.buffer.DataBuffer;
    import org.springframework.core.io.buffer.DataBufferFactory;
    import org.springframework.core.io.buffer.DefaultDataBuffer;
    import org.springframework.core.io.buffer.DefaultDataBufferFactory;
    import org.springframework.http.server.reactive.ServerHttpRequest;
    import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
    import org.springframework.http.server.reactive.ServerHttpResponse;
    import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
    import org.springframework.web.server.ServerWebExchange;
    import reactor.core.publisher.Flux;
    import reactor.core.publisher.Mono;
    import reactor.core.scheduler.Schedulers;
    
    import java.io.ByteArrayOutputStream;
    import java.nio.channels.Channels;
    import java.nio.charset.StandardCharsets;
    
    @Configuration
    @Log4j2
    public class RequestResponseModifyFilter implements GatewayFilter/WebFilter, Ordered {
    
    
        @Override
        public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
    
            String path = exchange.getRequest().getPath().toString();
            ServerHttpResponse response = exchange.getResponse();
            ServerHttpRequest request = exchange.getRequest();
            DataBufferFactory dataBufferFactory = response.bufferFactory();
    
            // log the request body
            ServerHttpRequest decoratedRequest = getDecoratedRequest(request);
            // log the response body
            ServerHttpResponseDecorator decoratedResponse = getDecoratedResponse(path, response, request, dataBufferFactory);
            return chain.filter(exchange.mutate().request(decoratedRequest).response(decoratedResponse).build());
        }
    
        private ServerHttpResponseDecorator getDecoratedResponse(String path, ServerHttpResponse response, ServerHttpRequest request, DataBufferFactory dataBufferFactory) {
            return new ServerHttpResponseDecorator(response) {
    
                @Override
                public Mono<Void> writeWith(final Publisher<? extends DataBuffer> body) {
    
                    if (body instanceof Flux) {
    
                        Flux<? extends DataBuffer> fluxBody = (Flux<? extends DataBuffer>) body;
    
                        return super.writeWith(fluxBody.buffer().map(dataBuffers -> {
    
                            DefaultDataBuffer joinedBuffers = new DefaultDataBufferFactory().join(dataBuffers);
                            byte[] content = new byte[joinedBuffers.readableByteCount()];
                            joinedBuffers.read(content);
                             String responseBody = new String(content, StandardCharsets.UTF_8);//MODIFY RESPONSE and Return the Modified response
                            log.debug("requestId: {}, method: {}, url: {}, \nresponse body :{}", request.getId(), request.getMethodValue(), request.getURI(), responseBody);
    
                            return dataBufferFactory.wrap(responseBody.getBytes());
                        })).onErrorResume(err -> {
    
                            log.error("error while decorating Response: {}",err.getMessage());
                            return Mono.empty();
                        });
    
                    }
                    return super.writeWith(body);
                }
            };
        }
    
        private ServerHttpRequest getDecoratedRequest(ServerHttpRequest request) {
    
            return new ServerHttpRequestDecorator(request) {
                @Override
                public Flux<DataBuffer> getBody() {
    
                    log.debug("requestId: {}, method: {} , url: {}", request.getId(), request.getMethodValue(), request.getURI());
                    return super.getBody().publishOn(Schedulers.boundedElastic()).doOnNext(dataBuffer -> {
    
                        try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream()) {
    
                            Channels.newChannel(byteArrayOutputStream).write(dataBuffer.asByteBuffer().asReadOnlyBuffer());
                            String requestBody = IOUtils.toString(byteArrayOutputStream.toByteArray(), StandardCharsets.UTF_8.toString());//MODIFY REQUEST and Return the Modified request
                            log.debug("for requestId: {}, request body :{}", request.getId(), requestBody);
                        } catch (Exception e) {
                            log.error(e.getMessage());
                        }
                    });
                }
            };
        }
    
        @Override
        public int getOrder() { return -2;}
    }