Search code examples
Tags: kotlin, spring-cloud, spring-cloud-gateway

How to write custom GlobalFilter for checking request body in Spring Cloud Gateway?


I would like to validate the request body in a `GlobalFilter`.

I need to read two HTTP headers that contain a checksum of the body and compare it against the body itself:

/**
 * Original attempt from the question: decode the request body inside a [GlobalFilter] and check it
 * against request headers before continuing the chain.
 *
 * NOTE(review): this is the version reported to hang. Likely causes, to be confirmed:
 *  - `flatMap` only runs its lambda when the decoded Mono emits a value. If decoding completes empty
 *    (presumably for requests without a body), `chain.filter(exchange)` is never invoked, so nothing
 *    is forwarded.
 *  - `exchange.request.body` is consumed here, yet `chain.filter(exchange)` forwards the same,
 *    undecorated exchange; downstream code that reads the body again presumably gets nothing or an
 *    error, since the body is not replayed.
 */
internal class MyFilter : GlobalFilter {

    // The returned Mono<Void> comes from flatMap, so the chain only continues inside the flatMap lambda.
    override fun filter(exchange: ServerWebExchange, chain: GatewayFilterChain) =
        ByteArrayDecoder()
            .decodeToMono(
                exchange.request.body,
                // NOTE(review): the target type is ByteBuffer while the decoder is ByteArrayDecoder;
                // ByteArray::class.java (or a ByteBufferDecoder) looks intended — confirm.
                ResolvableType.forClass(ByteBuffer::class.java),
                exchange.request.headers.contentType,
                null // no decoding hints
            )
            .flatMap { /* my logic checking body against request headers */ chain.filter(exchange) }
}

The problem is that `decodeToMono` gets stuck and the request is never forwarded.

How can I decode the body properly?


Solution

  • I've managed to write a filter that does not get stuck after reading the body:

    /**
     * Strategy that inspects a request body and decides whether the request may continue.
     *
     * @param body the request body, decoded as a [ByteArrayResource].
     * @param exchange the current exchange, giving access to request headers such as checksums.
     * @param passRequestFunction invoke and return its result to forward the request downstream.
     * @return the completion signal for this filter step; return `passRequestFunction()` to accept the
     *   request, or any other `Mono<Void>` (for example `Mono.empty()`) to stop it here.
     */
    interface BodyFilter {
        fun filter(body: Mono<ByteArrayResource>, exchange: ServerWebExchange, passRequestFunction: () -> Mono<Void>): Mono<Void>
    }
    
    /**
     * [GlobalFilter] that hands the request body to a [BodyFilter] for validation before forwarding.
     *
     * The body is decoded once into a [ByteArrayResource] and cached. When the [BodyFilter] accepts the
     * request, the cached bytes are re-encoded into a [CachedBodyOutputMessage] and exposed to the rest
     * of the chain through a [ServerHttpRequestDecorator]. This replay is needed because the original
     * request body has already been consumed.
     *
     * NOTE: for a request without a body the `body` Mono presumably completes empty. A [BodyFilter] that
     * only calls `passRequestFunction` from inside `body.flatMap { }` would then never forward the
     * request. Handle that case explicitly in the [BodyFilter] (e.g. with `switchIfEmpty`).
     */
    class HeaderAndBodyGlobalFilter(private val bodyFilter: BodyFilter) : GlobalFilter {

        private val messageReaders: List<HttpMessageReader<*>> = HandlerStrategies.withDefaults().messageReaders()

        override fun filter(exchange: ServerWebExchange, chain: GatewayFilterChain): Mono<Void> {
            val serverRequest: ServerRequest = ServerRequest.create(exchange, messageReaders)
            // cache() is essential: `body` is subscribed once by bodyFilter (to validate it) and again by
            // reconstructRequest (to replay it downstream). Without caching, every subscription re-reads
            // the underlying one-shot request body, and the second read fails or yields nothing.
            // ByteArrayResource holds a heap copy of the bytes, so caching it retains no pooled buffers.
            val body: Mono<ByteArrayResource> = serverRequest.bodyToMono(ByteArrayResource::class.java).cache()
            return bodyFilter.filter(body, exchange) { reconstructRequest(body, exchange, chain) }
        }

        /**
         * Writes [body] into a [CachedBodyOutputMessage] and continues [chain] with a request whose body
         * and headers come from that message.
         */
        private fun reconstructRequest(
            body: Mono<ByteArrayResource>,
            exchange: ServerWebExchange,
            chain: GatewayFilterChain
        ): Mono<Void> {
            val headers: HttpHeaders = writableHttpHeaders(exchange.request.headers)
            val outputMessage = CachedBodyOutputMessage(exchange, headers)

            return BodyInserters.fromPublisher(body, ByteArrayResource::class.java)
                .insert(outputMessage, BodyInserterContext())
                // Defer so the decorator is built only after the body has been written to outputMessage.
                .then(Mono.defer {
                    val decorator: ServerHttpRequestDecorator = decorate(exchange, headers, outputMessage)
                    chain.filter(exchange.mutate().request(decorator).build())
                })
        }

        /**
         * Wraps the original request so downstream filters see the replayed body from [outputMessage].
         * Content-Length is taken from [headers]; when it is not positive, chunked transfer encoding is
         * declared instead.
         */
        private fun decorate(
            exchange: ServerWebExchange,
            headers: HttpHeaders,
            outputMessage: CachedBodyOutputMessage
        ): ServerHttpRequestDecorator {
            return object : ServerHttpRequestDecorator(exchange.request) {
                override fun getHeaders(): HttpHeaders {
                    val contentLength = headers.contentLength
                    val httpHeaders = HttpHeaders()
                    httpHeaders.putAll(super.getHeaders())
                    if (contentLength > 0) {
                        httpHeaders.contentLength = contentLength
                    } else {
                        // TODO: some upstreams (e.g. httpbin.org) reject chunked requests with
                        // 'HTTP/1.1 411 Length Required'.
                        httpHeaders.set(HttpHeaders.TRANSFER_ENCODING, "chunked")
                    }
                    return httpHeaders
                }

                override fun getBody(): Flux<DataBuffer> {
                    return outputMessage.body
                }
            }
        }
    }
    

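    The `BodyFilter` implementation then either returns `Mono.empty()` on failure or calls `passRequestFunction` on success. Note that returning `Mono.empty()` completes the exchange without forwarding the request. If the client should see a rejection, set an error status (e.g. 400) on `exchange.response` before completing.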