java spring-boot spring-webflux project-reactor micrometer-tracing

How to manually create and inject TraceContext into Reactor Context in Spring Webflux?


Background: I have a service running on Spring Boot 3.1.0 which communicates via REST and AMQP. After it is invoked via REST, it publishes the REST payload's content to a RabbitMQ queue using reactor-rabbitmq and immediately returns an HTTP response.

@RestController
@RequiredArgsConstructor
@RequestMapping("${api.baseurl}")
public class CalculationInitiateController {

    private final RequestMapper requestMapper;
    private final ResponseMapper responseMapper;
    private final CalculationInitializer initializer;

    @ResponseStatus(code = HttpStatus.CREATED)
    @PostMapping("/initiate")
    public Mono<CalculationInitResponseDto> initiateCalculation(@RequestBody @Valid CalculationInitiationRequestDto request) {
        return Mono.just(requestMapper.map(request))
                .flatMap(initializer::initialize)
                .map(responseMapper::map);
    }
}

Event publishing via reactor-rabbitmq:

@Slf4j
@AllArgsConstructor
@Service
public class CalculationInitializer {

    private final Sender sender;
    private final ObjectMapper objectMapper;
    private final EventPublisherProperties eventPublisherProperties;
    private final OutboundMessageFactory messageFactory;
    private final Tracer tracer; // used by messageProperties() below

    public Mono<Boolean> initialize(CalculationEvent calculationEvent) {
        // eventType and eventConfig are presumably resolved from eventPublisherProperties; that lookup is not shown
        log.info("Publishing internal {} event for calculation: {}", eventType, calculationEvent.request().calculationId());
        var bytePayload = objectMapper.writeValueAsBytes(calculationEvent);
        var outboundMessage = new OutboundMessage("", eventConfig.getRoutingKey(), messageProperties(eventConfig), bytePayload);
        return send(outboundMessage);
    }

    private AMQP.BasicProperties messageProperties(EventPublisherProperties.EventConfig eventConfig) {
        var context = tracer.currentTraceContext().context();
        return new AMQP.BasicProperties.Builder()
                .correlationId(context.traceId())
                .headers(Map.of(EVENT_TYPE_HEADER, eventConfig.headerValue(), "traceId", context.traceId(), "spanId", context.spanId()))
                .build();
    }

    private Mono<Boolean> send(OutboundMessage outboundMessage) {
        return sender.sendWithPublishConfirms(Mono.just(outboundMessage)).next()
                .flatMap(this::checkIfAcknowledged);
    }

    private Mono<Boolean> checkIfAcknowledged(OutboundMessageResult<OutboundMessage> result) {
        if (result.isAck()) {
            return Mono.just(Boolean.TRUE);
        } else {
            log.warn("Message not Acknowledged !!");
            return Mono.error(new IllegalStateException("Did not receive ACK on message send"));
        }
    }
}

A Rabbit receiver (also from reactor-rabbitmq) later on consumes the message and calls multiple APIs.

@Slf4j
@Component
@AllArgsConstructor
public class EventMessageListener {

    private final EventMessageReceiverProperties eventMessageReceiverProperties;
    private final Receiver eventReceiver;
    private final ConsumeOptions consumeOptions;
    private final Tracer tracer;
    private final CalculationEventHandler calculationEventHandler;

    @EventListener(ApplicationReadyEvent.class)
    public void receiveMessages() {
        eventReceiver.consumeManualAck(eventMessageReceiverProperties.getQueue(), consumeOptions)
                .flatMap(this::handleMessage)
                .doFinally(s -> eventReceiver.close())
                .subscribe();
    }

    private Mono<Void> handleMessage(AcknowledgableDelivery message) {
        var traceId = message.getProperties().getHeaders().get("traceId").toString();
        var spanId = message.getProperties().getHeaders().get("spanId").toString();
        var context = tracer.traceContextBuilder().traceId(traceId).spanId(spanId).build();
        return Mono.defer(() -> calculationEventHandler.handle(message)
                .doOnSuccess(v -> message.ack())
                .onErrorResume(ex -> {
                    log.error("Failed to handle message {}", new String(message.getBody()));
                    log.error("Exception:", ex);
                    message.nack(false);
                    return Mono.empty();
                }))
                .contextWrite(Context.of(TraceContext.class, context));
    }
}

This event is then processed:

@Component
@AllArgsConstructor
public class CalculationEventHandler  {

    private final ObjectMapper objectMapper;
    private final CalculationErrorHandler calculationErrorHandler;
    private final CalculationEventProcessor calculationEventProcessor;
    private final EventMessageReceiverProperties eventMessageReceiverProperties;

    public Mono<Void> handle(AcknowledgableDelivery message) {
        try {
            var event = objectMapper.readValue(message.getBody(), CalculationEvent.class);
            return calculationEventProcessor.process(event)
                    .onErrorResume(t -> calculationErrorHandler.handleError(message, t, event))
                    .then(Mono.empty());
        } catch (IOException e) {
            return Mono.error(e);
        }
    }
}

Processing includes validating the payload using some external APIs which are called in multiple ValidationService implementations. I'm using Flux.mergeDelayError to wait for all the responses from API services to assemble all errors if multiple API calls failed:

@Service
@AllArgsConstructor
public class CalculationValidator {

    private final List<ValidationService> validationServices;

    public Mono<Void> validate(CalculationInput input) {
        return Flux.mergeDelayError(Queues.XS_BUFFER_SIZE, createAttributeRequests(input).toArray(Publisher[]::new)).then();
    }

    private List<Mono<Void>> createAttributeRequests(CalculationInput input) {
        return validationServices.stream()
                .map(validationService -> validationService.validate(input))
                .toList();
    }
}

Here's what most of the ValidationService implementations look like:

@Service
public class ThresholdValidatorServiceImpl implements ValidationService {

    private final WebClient webClient;
    private final String endpoint;
    private final ThresholdInputMapper mapper;

    public ThresholdValidatorServiceImpl(WebClient.Builder builder,
                                        ValidationServiceErrorFilterFactory errorFilterFactory,
                                        ThresholdInputMapper mapper,
                                        @Value("${integration.threshold-validator.url}") String gateway,
                                        @Value("${integration.threshold-validator.endpoint}") String endpoint) {
        this.webClient = builder
                .baseUrl(gateway)
                .filter(errorFilterFactory.createFilterFor("threshold-validator"))
                .build();
        this.endpoint = endpoint;
        this.mapper = mapper;
    }

    @Override
    public Mono<Void> validate(CalculationInput input) {
        return Mono.just(mapper.map(input))
                // flatMap, not map: the WebClient call is itself a Mono and must be subscribed
                .flatMap(body -> webClient.post()
                        .uri(endpoint)
                        .contentType(MediaType.APPLICATION_JSON)
                        .bodyValue(body)
                        .retrieve()
                        .bodyToMono(Void.class));
    }
}

What I am after: I want to reuse the traceId from the initial REST request. The plan is to put the traceId into the AMQP message properties and, after the Rabbit Receiver consumes the message, inject it into the Reactor Context so that the WebClient calls in each ValidationService use the same traceId as the initial request.

Problem: I am able to save the traceId from the initial request and put it into the AMQP message properties. After consuming the RabbitMQ message, I read the traceId and write it to the Reactor Context as a TraceContext. My understanding is that this context is then visible upstream in the reactive pipeline, where the API calls are made. However, the WebClient seems to generate a new traceId for each .exchange(), which is not the behaviour I expect.

Question: Is this even possible to achieve? If yes, what would be the correct approach?

Dependencies used:

  • io.micrometer:micrometer-tracing:1.1.2
  • io.micrometer:context-propagation:1.0.3
  • io.projectreactor.rabbitmq:reactor-rabbitmq:1.5.6
  • org.springframework.boot:spring-boot-starter-webflux:3.1.0

EDIT: Added some code for more clarity, updated descriptions.


Solution

  • Spring Boot 3.x uses Micrometer Tracing, and all of Spring Boot's default configuration works with the Micrometer Observation API. The WebFlux WebClient looks in the current context for an Observation stored under the key "micrometer.observation". If it finds none, it starts a new observation, which is why the following code does not work:

    contextWrite(Context.of(TraceContext.class, context));
    

    Possible solutions to this problem:

    1. Use the Spring Cloud Stream binder for RabbitMQ. It supports the reactive programming model, including receiving messages from the queue as a Flux, and should cover all of your scenarios with little or no tracing code of your own. Spring WebFlux reads the trace/span ids from the incoming request headers and adds them to the request context. They are then attached to outgoing remote calls, both web calls and RabbitMQ messages. On incoming messages the headers are read again and a span is created from them for any further outgoing calls. You can verify this in the logs if you use MDC (Logback etc.), since the trace information is added to the MDC and most context-switching cases are handled for you.

    2. Instead of putting a TraceContext into the context, write the individual ids (trace, span, parent, etc.) directly into the Reactor Context and copy them into the appropriate headers when making downstream calls. You do not need to worry about context switches or which thread executes the WebClient calls: calling Hooks.enableAutomaticContextPropagation(); in the main class should propagate the context with minimal overhead.

    3. This is a bit more complex, but I could not think of an easier solution. Create a custom io.micrometer.observation.transport.ReceiverContext and start a new Observation with it. You will not have to read the span ids from the headers yourself, because the default propagator extracts them from the message headers. This takes inspiration from what the Spring RabbitMQ message listener does.

    Create a new ReceiverContext class similar to org.springframework.amqp.rabbit.support.micrometer.RabbitMessageReceiverContext:
    
    // Create a new observation in the receiver.
    // Resolve the ObservationRegistry bean and pass the custom receiver context.
    Observation observation = Observation.createNotStarted(...);
    // Run your code inside the observation; this puts it into the ThreadLocal and the Context.
    observation.observe(() -> { /* your code */ });

    // The PropagatingReceiverTracingObservationHandler already registered with the registry
    // handles this observation and extracts the required headers from the message.
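
For option 2, the ids taken from the AMQP headers still have to be turned into a header value that the downstream services understand. Below is a minimal sketch of that step, assuming the downstream services expect the W3C Trace Context "traceparent" header and that the trace id is the 32-character hex form. TraceParent is a hypothetical helper written for this illustration; it is not part of Micrometer, Spring, or Reactor.

```java
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Hypothetical helper (not a Micrometer/Spring class) that converts between
 * raw trace/span ids and a W3C "traceparent" header value.
 */
record TraceParent(String traceId, String spanId, boolean sampled) {

    // Layout: version "00", 32 hex chars trace id, 16 hex chars parent (span) id, 2 hex chars flags.
    private static final Pattern FORMAT =
            Pattern.compile("00-([0-9a-f]{32})-([0-9a-f]{16})-([0-9a-f]{2})");

    private static final String ZERO_TRACE_ID = "0".repeat(32);
    private static final String ZERO_SPAN_ID = "0".repeat(16);

    /** Parses a header value; returns empty if it is missing or malformed. */
    static Optional<TraceParent> parse(String header) {
        if (header == null) {
            return Optional.empty();
        }
        Matcher m = FORMAT.matcher(header.trim());
        if (!m.matches()) {
            return Optional.empty();
        }
        String traceId = m.group(1);
        String spanId = m.group(2);
        // All-zero ids are invalid according to the W3C Trace Context spec.
        if (traceId.equals(ZERO_TRACE_ID) || spanId.equals(ZERO_SPAN_ID)) {
            return Optional.empty();
        }
        // Bit 0 of the flags byte is the "sampled" flag.
        boolean sampled = (Integer.parseInt(m.group(3), 16) & 1) == 1;
        return Optional.of(new TraceParent(traceId, spanId, sampled));
    }

    /** Renders the value to send as the "traceparent" header of a downstream call. */
    String toHeader() {
        return "00-" + traceId + "-" + spanId + "-" + (sampled ? "01" : "00");
    }
}
```

In the listener, something like new TraceParent(traceId, spanId, true).toHeader() could then be stored in the Reactor Context and set on each request with .header("traceparent", value). Note that this makes every downstream call a child of the original span, because no new span is created on the consumer side. Options 1 and 3 avoid that by letting the tracing library create the consumer span.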