java, jackson, spring-webflux, project-reactor, spring-webclient

Subscribing to WebClient Flux<DataBuffer> response for async JSON reading not working


I am having issues with some custom handling of a WebClient response: I am trying to make it play nicely with Jackson's asynchronous (non-blocking) JSON parsing.

I have verified that the parsing works very well if I replace the source of the Flux<DataBuffer> with a simpler in-memory one:

final var dataBufferFlux = DataBufferUtils.readInputStream(
  () ->  new ByteArrayInputStream(bytes), DefaultDataBufferFactory.sharedInstance, 32
);

But once the source of the Flux is the WebClient (as in the test case below), it stops working. The subscription.request(1) becomes a no-op, and my onNext is never called.

Any ideas? Is the WebClient response already consumed?

I have, however, managed some ugly test rewrites with Flux#blockFirst and the like that did return results (before crashing, because I was blocking a reactor thread), so there is content inside the Flux.
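
In essence, the probe looked something like this (a simplified sketch, reusing containerMono from the test case below):

// blockFirst does return a buffer, so the response body is present;
// but parking a reactor thread like this is what eventually crashes.
final var probedBuffer = containerMono
    .flatMapMany(ContainerClass::getDataBufferFlux)
    .blockFirst(Duration.ofSeconds(5));
Assertions.assertNotNull(probedBuffer);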

class StreamedParsingTest {

  @Test
  @SneakyThrows
  void testStreamedReading() {

    final var objectMapper = new ObjectMapper();
    final var parser = (NonBlockingByteBufferJsonParser) objectMapper.getFactory().createNonBlockingByteBufferParser();
    final var feeder = (ByteBufferFeeder) parser.getNonBlockingInputFeeder();
    final var allowedToRequest = new AtomicBoolean(true);
    final var subscriptionRef = new AtomicReference<Subscription>();

    final var webClient = createWebClient(Duration.ofSeconds(10), Duration.ofSeconds(10));

    wm.stubFor(
        WireMock.post(WireMock.urlPathMatching(".*")).willReturn(
            WireMock.aResponse()
                .withStatus(200)
                .withBody(bytes)
                .withHeader("Content-Type", "application/json")
        ));

    final var containerMono = webClient.post()
        .uri("http://localhost:" + wm.getPort())
        .retrieve()
        .toEntityFlux(DataBuffer.class)
        .map(responseEntity -> new ContainerClass(responseEntity.getBody(), responseEntity.getStatusCode().value()));

    final var blockedValue = containerMono.flatMap(cc -> {

          cc.getDataBufferFlux()
              .subscribeOn(Schedulers.boundedElastic())
              .subscribe(new Subscriber<>() {
                @Override
                public void onSubscribe(Subscription s) {
                  subscriptionRef.set(s);
                }

                @Override
                public void onNext(DataBuffer dataBuffer) {
                  try {
                    if (feeder.needMoreInput()) {
                      feeder.feedInput(dataBuffer.asByteBuffer());
                      allowedToRequest.lazySet(true);
                    }
                  } catch (IOException e) {
                    throw new RuntimeException(e);
                  }
                }

                @Override
                public void onError(Throwable t) {
                  feeder.endOfInput();
                }

                @Override
                public void onComplete() {
                  feeder.endOfInput();
                }
              });

          String foundValue = null;
          try {
            JsonToken token;
            boolean foundField = false;
            do {
              while ((token = parser.nextToken()) == JsonToken.NOT_AVAILABLE) {
                if (subscriptionRef.get() != null) {
                  if (allowedToRequest.compareAndSet(true, false)) {
                    subscriptionRef.get().request(1);
                  }
                }
              }

              if (foundField && token.isScalarValue()) {
                foundValue = parser.getText();
                subscriptionRef.get().cancel();
                break;
              } else if (token == JsonToken.FIELD_NAME && StringUtils.equals(parser.getText(), "Field2")) {
                foundField = true;
              }
            } while (token != null);
          } catch (IOException ex) {
            return Mono.error(new RuntimeException(ex));
          }

          return Mono.justOrEmpty(foundValue);
        })
        .block();

    Assertions.assertEquals("bar", blockedValue);
  }

  private static final byte[] bytes = """
      {
        "Object": {
          "Field1": "foo",
          "Field2": "bar",
          "Field3": "xyz
        }
      }
      """.getBytes(StandardCharsets.UTF_8);

  @RegisterExtension
  static WireMockExtension wm = WireMockExtension.newInstance()
      .options(
          WireMockConfiguration.wireMockConfig()
              .dynamicPort()
              .notifier(new Slf4jNotifier(true))
              .asynchronousResponseEnabled(true)
              .asynchronousResponseThreads(100)
      )
      .build();

  private static WebClient createWebClient(Duration connectTimeout, Duration readTimeout) {

    var connector = new ReactorClientHttpConnector(
        HttpClient
            .create(ConnectionProvider.create("Test-Connection"))
            .tcpConfiguration(tcpClient -> tcpClient
                .metrics(true)
                .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, (int) connectTimeout.toMillis())
                .doOnConnected(conn -> conn
                    .addHandlerLast(new ReadTimeoutHandler(readTimeout.toMillis(), TimeUnit.MILLISECONDS))
                )
            )
    );

    return WebClient.builder()
        .clientConnector(connector)
        .codecs(clientCodecConfigurer -> clientCodecConfigurer.defaultCodecs().maxInMemorySize(64))
        .build();
  }

  @Value
  private static class ContainerClass {

    Flux<DataBuffer> dataBufferFlux;
    int status;
  }
}

Worth noting that this uses the latest non-final Jackson version, which adds support for the non-blocking ByteBuffer feeder. If you need to run the test case, you can replace those parts with the plain byte[] feeder instead.
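
For reference, a sketch of the equivalent setup with the plain byte[] feeder available in released Jackson versions:

// byte[]-based equivalent of the ByteBuffer feeder used above
// (non-blocking parsing has been available since Jackson 2.9):
final var parser = (NonBlockingJsonParser) objectMapper.getFactory().createNonBlockingByteArrayParser();
final var feeder = (ByteArrayFeeder) parser.getNonBlockingInputFeeder();
// chunks are then fed with feeder.feedInput(chunk, 0, chunk.length)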


Solution

  • I found a solution to this by taking a similar but different approach.

    By creating Flux sinks that are fed from a separate thread, I can stream the Jackson JSON content in a non-blocking, low-memory fashion.

    For the two scenarios, writing and reading JSON as a stream within a reactive flow, I created these two classes:

    ByteBufferJsonPublisher:

    @Slf4j // Lombok-generated logger ("log") used by this class
    public class ByteBufferJsonPublisher {
    
      private static final DataBufferFactory BUFFER_FACTORY = new NettyDataBufferFactory(
          ByteBufAllocator.DEFAULT
      );
    
      private static final Executor PUBLISH_EXECUTOR = Executors.newCachedThreadPool(
          new ThreadFactoryBuilder()
              .setNameFormat("ByteBufferJsonPublisher-%d")
              .setDaemon(true)
              .build()
      );
    
      private final AtomicBoolean isSubscriptionComplete = new AtomicBoolean(false);
      private final ObjectMapper objectMapper;
    
      public ByteBufferJsonPublisher(ObjectMapper objectMapper) {
        this.objectMapper = objectMapper;
      }
    
      public static Flux<DataBuffer> publish(ObjectMapper objectMapper, GeneratorConsumer generatorCallback) {
        return new ByteBufferJsonPublisher(objectMapper).publishInternal(generatorCallback);
      }
    
      /**
       * This publishing method reuses data buffers, so you MUST NOT buffer those data buffers for later re-use. They WILL have their contents
       * changed regularly.
       */
      private Flux<DataBuffer> publishInternal(GeneratorConsumer generatorCallback) {
    
        return Flux.create(
            sink -> {
              sink.onCancel(() -> {
                isSubscriptionComplete.set(true);
                log.debug("Json Publishing was cancelled");
              });
              PUBLISH_EXECUTOR.execute(() -> executePublishingSync(sink, generatorCallback));
            },
    
            // The buffer strategy is BUFFER, so it is *technically* possible we could use too much memory here on backpressure issues.
            // But we try to take care listening to the number of requests made.
            OverflowStrategy.BUFFER
        );
      }
    
      private void executePublishingSync(FluxSink<DataBuffer> sink, GeneratorConsumer generatorCallback) {
    
        final var os = new DelegatingBlockableOutputStream(sink);
    
        try (os) {
          try (final var generator = objectMapper.createGenerator(os)) {
            generatorCallback.accept(generator);
          } catch (SignalCompletionException ex) {
            log.debug("Aborted publisher, since completion exception signal was thrown: %s".formatted(ex));
            sink.complete();
            return;
          } catch (Exception ex) {
            sink.error(ex);
            return;
          }
    
          sink.complete();
    
        } catch (IOException ex) {
          log.warn("Could not create or close the delegating output stream", ex);
        }
      }
    
      private void ensureActiveOrThrowExceptionSignal() {
        if (isSubscriptionComplete.get()) {
          throw new SignalCompletionException();
        }
      }
    
      private static class SignalCompletionException extends RuntimeException {
    
      }
    
      private class DelegatingBlockableOutputStream extends OutputStream {
    
        private final FluxSink<DataBuffer> sink;
    
        public DelegatingBlockableOutputStream(FluxSink<DataBuffer> sink) {
          this.sink = sink;
        }
    
        @Override
        public void write(int b) {
          write(new byte[]{(byte) b}, 0, 1);
        }
    
        @Override
        @SneakyThrows
        public void write(byte[] b, int off, int len) {
    
          DataBuffer buffer = null;
          while (len > 0) {
            ensureActiveOrThrowExceptionSignal();
    
            if (buffer == null) {
    
              // If demand is bounded, we should try to wait (briefly) for downstream to
              // catch up, and move on once a short maximum wait has elapsed.

              // If no buffer is available, we simply allocate a new one, because we have to.
              buffer = BUFFER_FACTORY.allocateBuffer(StreamingUtils.getMaxBytesInMemory()).writePosition(0);
            }
    
            final var writeLength = Math.min(len, buffer.writableByteCount());
            if (writeLength <= 0) {
              throw new IllegalArgumentException("This should never be 0 or less");
            }
    
            buffer.write(b, off, writeLength);
            if (buffer.writableByteCount() == 0) {
    
              sink.next(buffer);
              buffer = null;
            }
    
            // Then step forward with that amount in the current byte array.
            // There might be more left inside it to write to a next byte buffer.
            off += writeLength;
            len -= writeLength;
          }
    
          ensureActiveOrThrowExceptionSignal();
          if (buffer != null && buffer.writePosition() > 0) {
            sink.next(buffer);
          }
        }
      }
    }
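
    A minimal usage sketch of the publishing side, assuming GeneratorConsumer is a functional interface that hands the callback a JsonGenerator (and is allowed to throw); the field name is just an example:

    final var objectMapper = new ObjectMapper();

    // The generator callback writes JSON as if to a plain OutputStream,
    // while the publisher turns it into a backpressure-aware Flux.
    final Flux<DataBuffer> body = ByteBufferJsonPublisher.publish(objectMapper, generator -> {
      generator.writeStartObject();
      generator.writeStringField("Field2", "bar");
      generator.writeEndObject();
    });

    Such a Flux can then be used directly as a request body, for example via BodyInserters.fromDataBuffers(body).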
    

    and ByteBufferJsonConsumer:

    @Slf4j // Lombok-generated logger ("log") used by this class
    public class ByteBufferJsonConsumer implements AutoCloseable {
    
      private final NonBlockingByteBufferJsonParser parser;
    
      List<String> remainingKeys;
      Map<String, Object> map;
      String picked;
    
      public ByteBufferJsonConsumer(JsonFactory jsonFactory, String... keys) {
    
        try {
          this.parser = (NonBlockingByteBufferJsonParser) jsonFactory.createNonBlockingByteBufferParser();
        } catch (IOException ex) {
          throw new CaughtRuntimeIOException("Could not create the non-blocking parser", ex);
        }
    
        this.remainingKeys = new ArrayList<>(Arrays.asList(keys));
        this.map = new HashMap<>();
      }
    
      public void picks(String key) {
        this.picked = key;
      }
    
      public boolean expectsEatingWhatWasPicked() {
        return this.picked != null;
      }
    
      public void eats(Object value) {
    
        if (this.picked == null) {
          throw new IllegalArgumentException("There collector has not been marked as being at a field");
        }
    
        if (!this.remainingKeys.remove(this.picked)) {
          throw new IllegalArgumentException("Key '%s' was never registered or was already added".formatted(this.picked));
        }
    
        this.map.put(this.picked, value);
        this.picked = null;
      }
    
      public boolean wants(String key) {
        return this.remainingKeys.contains(key);
      }
    
      public Object get(String key) {
        return this.map.get(key);
      }
    
      public Map<String, Object> getMap() {
        return this.map;
      }
    
      public String picked() {
        return this.picked;
      }
    
      public boolean wantsMore() {
        return !this.remainingKeys.isEmpty();
      }
    
      public void starve() {
        this.remainingKeys.clear();
      }
    
      public NonBlockingByteBufferJsonParser getParser() {
        return parser;
      }
    
      @Override
      public void close() throws Exception {
        this.parser.close();
      }
    
      /**
       * This will consume the underlying stream. If you call this method twice, you might not get the same result.
       * <p>
       * WARNING! This can currently only fetch from FIELD NAME, and does not know of nested objects! This is because the otherwise used
       * FilteringParserDelegate does not support non-blocking async filtering! So to make things easier we only get tokens and find the field
       * name and that's that.
       */
      public static Mono<Map<String, Object>> collectJsonFields(
          Flux<DataBuffer> dataBufferFlux,
          JsonFactory jsonFactory,
          String[] wants
      ) {
    
        return dataBufferFlux
            .scan(new ByteBufferJsonConsumer(jsonFactory, wants), ByteBufferJsonConsumer::feedBufferToJsonConsumer)
            .takeUntil(consumer -> !consumer.wantsMore())
            .last()
            .doOnEach(signal -> {
    
              try {
                var consumer = signal.get();
                if (consumer != null) {
                  // The consumer should never be null, since on exception we just starve the consumer.
                  consumer.close();
                }
              } catch (Exception ex) {
                log.error("Could not close the consumer. Might resolve into a small temporary memory leak if used abundantly", ex);
              }
            })
            .map(ByteBufferJsonConsumer::getMap);
      }
    
      private static ByteBufferJsonConsumer feedBufferToJsonConsumer(ByteBufferJsonConsumer consumer, DataBuffer buffer) {
    
        var parser = consumer.getParser();
    
        JsonToken token;
        try {
    
          if (parser.needMoreInput()) {
            parser.feedInput(buffer.asByteBuffer());
            DataBufferUtils.release(buffer);
          }
    
          do {
            if ((token = parser.nextToken()) == JsonToken.NOT_AVAILABLE) {
              break;
            }
    
            pickOrEatOrSkipToken(consumer, parser);
          } while (token != null && consumer.wantsMore());
    
          return consumer;
    
        } catch (IllegalArgumentException | IOException ex) {
          log.error("Could not feed or consume the data buffer. Will starve the consumer and return what we have so far", ex);
          consumer.starve();
    
          return consumer;
        }
      }
    
      private static void pickOrEatOrSkipToken(ByteBufferJsonConsumer consumer, NonBlockingByteBufferJsonParser parser) throws IOException {
    
        if (consumer.expectsEatingWhatWasPicked()) {
          if (parser.currentToken().isScalarValue()) {
            consumer.eats(JsonParserUtils.getCurrentTokenValue(parser));
          } else {
            var message = "Fed consumer something it could not eat, '%s': '%s'".formatted(consumer.picked(), parser.currentToken());
            throw new IllegalArgumentException(message);
          }
        } else if (parser.currentToken() == JsonToken.FIELD_NAME && consumer.wants(parser.getText())) {
          consumer.picks(parser.getText());
        } else {
    
          // We do not care.
        }
      }
    }
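
    And a sketch of the reading side, wired against the dataBufferFlux and payload from the original test case (field name as in the question):

    // Collects only "Field2" from the streamed JSON; the takeUntil inside
    // collectJsonFields stops reading once every wanted key has been eaten.
    final var fields = ByteBufferJsonConsumer
        .collectJsonFields(dataBufferFlux, objectMapper.getFactory(), new String[]{"Field2"})
        .block();
    Assertions.assertEquals("bar", fields.get("Field2"));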
    

    They work great for my situation: fast, low-memory, cancellable, they handle errors well, and they cope reasonably well with downstream request demand.