Search code examples
Tags: spring, spring-boot, spring-cache, spring-cloud-gateway, caffeine-cache

Spring cloud gateway with Spring cache and caffeine


I have a Spring Cloud Gateway that forwards REST API requests to several microservices.

I would like to cache the responses to specific requests. For this I wrote the following filter:

/**
 * Route filter factory from the question: meant to cache responses in the "MyCache" cache, keyed by
 * request path, and replay them for repeated requests.
 *
 * <p>NOTE(review): per the question text, the second call returns an empty body. The filter stores
 * the {@code ServerHttpResponse} object of the first exchange (see {@code cache.put} below) rather
 * than its status, headers and body bytes. That object appears to have already been written by the
 * time {@code doOnSuccess} runs, so completing it again for a new request presumably sends no content.
 * The Solution below captures the body bytes with a ServerHttpResponseDecorator instead.
 *
 * <p>NOTE(review): the snippet as shown is incomplete: the {@code Config} class is not shown and the
 * braces do not balance (the class has no closing brace).
 */
@Component
@Slf4j
public class CacheResponseGatewayFilterFactory extends AbstractGatewayFilterFactory<CacheResponseGatewayFilterFactory.Config> {
   private final CacheManager cacheManager;

   public CacheResponseGatewayFilterFactory(CacheManager cacheManager) {
     super(CacheResponseGatewayFilterFactory.Config.class);
     this.cacheManager = cacheManager;
   }

   /** Builds the caching filter; {@code config} is not read. */
   @Override
   public GatewayFilter apply(CacheResponseGatewayFilterFactory.Config config) {
     // NOTE(review): CacheManager.getCache may return null if "MyCache" is not configured; not checked here.
     final var cache = cacheManager.getCache("MyCache");
     return (exchange, chain) -> {
        // The cache key is only the request path: HTTP method and query string are not part of it.
        final var path = exchange.getRequest().getPath();
        // The key is looked up twice (here and in the get below); an eviction in between would yield null.
        if (nonNull(cache.get(path))) {
            log.info("Return cached response for request: {}", path);
            // Retrieves the response object stored by an earlier exchange (see the put below).
            final var response = cache.get(path, ServerHttpResponse.class);
            // Swaps that old response into this exchange and completes it without writing any body.
            // NOTE(review): this is the likely source of the empty body on the second call.
            final var mutatedExchange = exchange.mutate().response(response).build();
            return mutatedExchange.getResponse().setComplete();
        }

        // After the downstream chain completes, stores this exchange's response object itself,
        // not a copy of its status, headers or body.
        return chain.filter(exchange).doOnSuccess(aVoid -> {
            cache.put(path, exchange.getResponse());
        });
    };
}

When I call my REST endpoint, the first time I receive the correct JSON, but the second time I get an empty body.

What am I doing wrong?

EDIT: This is a screenshot of exchange.getRequest() just before calling cache.put():

(screenshot not included)


Solution

  • I solved it by creating a GlobalFilter and a ServerHttpResponseDecorator. This code caches every successful (2xx) response regardless of route (it can easily be improved to cache only specific responses).

    Here is the code. I think it can still be improved; if you see how, please let me know.

    @Slf4j
    @Component
    public class CacheFilter implements GlobalFilter, Ordered {
      private final CacheManager cacheManager;
    
      public CacheFilter(CacheManager cacheManager) {
        this.cacheManager = cacheManager;
      }
    
      @Override
      public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        final var cache = cacheManager.getCache("MyCache");
    
        final var cachedRequest = getCachedRequest(exchange.getRequest());
        if (nonNull(cache.get(cachedRequest))) {
            log.info("Return cached response for request: {}", cachedRequest);
            final var cachedResponse = cache.get(cachedRequest, CachedResponse.class);
    
            final var serverHttpResponse = exchange.getResponse();
            serverHttpResponse.setStatusCode(cachedResponse.httpStatus);
            serverHttpResponse.getHeaders().addAll(cachedResponse.headers);
            final var buffer = exchange.getResponse().bufferFactory().wrap(cachedResponse.body);
            return exchange.getResponse().writeWith(Flux.just(buffer));
        }
    
        final var mutatedHttpResponse = getServerHttpResponse(exchange, cache, cachedRequest);
        return chain.filter(exchange.mutate().response(mutatedHttpResponse).build());
      }
    
      private ServerHttpResponse getServerHttpResponse(ServerWebExchange exchange, Cache cache, CachedRequest cachedRequest) {
        final var originalResponse = exchange.getResponse();
        final var dataBufferFactory = originalResponse.bufferFactory();
    
        return new ServerHttpResponseDecorator(originalResponse) {
    
            @NonNull
            @Override
            public Mono<Void> writeWith(@NonNull Publisher<? extends DataBuffer> body) {
                if (body instanceof Flux) {
                    final var flux = (Flux<? extends DataBuffer>) body;
                    return super.writeWith(flux.buffer().map(dataBuffers -> {
                        final var outputStream = new ByteArrayOutputStream();
                        dataBuffers.forEach(dataBuffer -> {
                            final var responseContent = new byte[dataBuffer.readableByteCount()];
                            dataBuffer.read(responseContent);
                            try {
                                outputStream.write(responseContent);
                            } catch (IOException e) {
                                throw new RuntimeException("Error while reading response stream", e);
                            }
                        });
                        if (Objects.requireNonNull(getStatusCode()).is2xxSuccessful()) {
                            final var cachedResponse = new CachedResponse(getStatusCode(), getHeaders(), outputStream.toByteArray());
                            log.debug("Request {} Cached response {}", cacheKey.getPath(), new String(cachedResponse.getBody(), UTF_8));
                            cache.put(cacheKey, cachedResponse);
                        }
                        return dataBufferFactory.wrap(outputStream.toByteArray());
                    }));
                }
                return super.writeWith(body);
            }
        };
      }
    
      @Override
      public int getOrder() {
        return -2;
      }
    
      private CachedRequest getCachedRequest(ServerHttpRequest request) {
        return CachedRequest.builder()
                .method(request.getMethod())
                .path(request.getPath())
                .queryParams(request.getQueryParams())
                .build();
      }
    
      @Value
      @Builder
      private static class CachedRequest {
        RequestPath path;
        HttpMethod method;
        MultiValueMap<String, String> queryParams;
    
      }
    
      @Value
      private static class CachedResponse {
        HttpStatus httpStatus;
        HttpHeaders headers;
        byte[] body;
      }
    }