-1

Background: I have a service running on Spring Boot 3.1.0 which communicates via REST and AMQP. After it is invoked via REST, it publishes the REST payload's content to a RabbitMQ Queue using reactor-rabbitmq and immediately returns a HTTP response.

/**
 * REST entry point for starting a calculation.
 *
 * <p>The validated request body is mapped, handed to the {@link CalculationInitializer}
 * and the outcome is mapped to the response DTO. The endpoint answers with 201 CREATED.
 */
@RestController
@RequiredArgsConstructor
@RequestMapping("${api.baseurl}")
public class CalculationInitiateController {

    private final RequestMapper requestMapper;
    private final ResponseMapper responseMapper;
    private final CalculationInitializer initializer;

    /**
     * Initiates a calculation for the given request.
     *
     * @param request the validated request payload
     * @return the mapped response, emitted once the initializer's Mono has completed
     */
    @ResponseStatus(code = HttpStatus.CREATED)
    @PostMapping("/initiate")
    public Mono<CalculationInitResponseDto> initiateCalculation(@RequestBody @Valid CalculationInitiationRequestDto request) {
        return Mono.just(requestMapper.map(request))
                // CalculationInitializer exposes initialize(...); it has no initiateCalculation(...) method.
                .flatMap(initializer::initialize)
                .map(responseMapper::map);
    }
}

Event publishing via reactor-rabbitmq:

/**
 * Publishes calculation events to RabbitMQ through reactor-rabbitmq with publisher confirms.
 *
 * <p>NOTE(review): {@code eventType}, {@code eventConfig}, {@code tracer} and
 * {@code EVENT_TYPE_HEADER} are used below but are not declared in this excerpt —
 * confirm where they come from.
 */
@Slf4j
@AllArgsConstructor
@Service
public class CalculationInitializer {

    private final Sender sender;
    private final ObjectMapper objectMapper;
    private final EventPublisherProperties eventPublisherProperties;
    private final OutboundMessageFactory messageFactory;

    /**
     * Serializes the event and publishes it to the default exchange.
     *
     * @param calculationEvent the event to publish
     * @return a Mono emitting {@code true} once the broker acknowledged the message; it fails
     *         when serialization fails or when the broker does not acknowledge the message
     */
    public Mono<Boolean> initialize(CalculationEvent calculationEvent) {
        log.info("Publishing internal {} event for calculation: {}", eventType, calculationEvent.request().calculationId());
        var routingKey = eventConfig.getRoutingKey();
        var properties = messageProperties(eventConfig);
        // writeValueAsBytes throws a checked exception; fromCallable turns it into an error signal
        // instead of leaving it unhandled in a method that does not declare it.
        return Mono.fromCallable(() -> objectMapper.writeValueAsBytes(calculationEvent))
                .map(bytePayload -> new OutboundMessage("", routingKey, properties, bytePayload))
                .flatMap(this::send);
    }

    /**
     * Builds the AMQP properties carrying the event type and the current trace and span ids.
     *
     * @param eventConfig configuration supplying the event-type header value
     * @return the properties to attach to the outbound message
     */
    private AMQP.BasicProperties messageProperties(EventPublisherProperties.EventConfig eventConfig) {
        // NOTE(review): context() may return null when no span is current — confirm callers always run inside one.
        var context = tracer.currentTraceContext().context();
        return new AMQP.BasicProperties.Builder()
                .correlationId(context.traceId())
                .headers(Map.of(EVENT_TYPE_HEADER, eventConfig.headerValue(), "traceId", context.traceId(), "spanId", context.spanId()))
                .build();
    }

    /**
     * Sends a single message with publisher confirms and evaluates the first confirm result.
     */
    private Mono<Boolean> send(OutboundMessage outboundMessage) {
        return sender.sendWithPublishConfirms(Mono.just(outboundMessage)).next()
                .flatMap(this::checkIfAcknowledged);
    }

    /**
     * Maps a confirm result to {@code true}, or to an {@link IllegalStateException} when it is not an ack.
     */
    private Mono<Boolean> checkIfAcknowledged(OutboundMessageResult<OutboundMessage> result) {
        if (result.isAck()) {
            return Mono.just(Boolean.TRUE);
        } else {
            log.warn("Message not Acknowledged !!");
            return Mono.error(new IllegalStateException("Did not receive ACK on message send"));
        }
    }
}

A Rabbit receiver (also from reactor-rabbitmq) later on consumes the message and calls multiple APIs.

/**
 * Consumes messages from the configured queue with manual acknowledgement and
 * forwards each delivery to the {@link CalculationEventHandler}.
 */
@Slf4j
@Component
@AllArgsConstructor
public class EventMessageListener {

    private final EventMessageReceiverProperties eventMessageReceiverProperties;
    private final Receiver eventReceiver;
    private final ConsumeOptions consumeOptions;
    private final Tracer tracer;
    private final CalculationEventHandler calculationEventHandler;

    /**
     * Starts consuming once the application is ready; the receiver is closed when the stream terminates.
     */
    @EventListener(ApplicationReadyEvent.class)
    public void receiveMessages() {
        eventReceiver.consumeManualAck(eventMessageReceiverProperties.getQueue(), consumeOptions)
                .flatMap(this::handleMessage)
                .doFinally(s -> eventReceiver.close())
                .subscribe();
    }

    /**
     * Handles one delivery: acks it on success, logs and nacks it (without requeue) on any failure.
     *
     * <p>The trace headers are read inside the deferred pipeline so that a message with missing
     * headers is logged and nacked like any other failure. Reading them eagerly would throw inside
     * {@code flatMap}, which terminates the whole consumer stream.
     *
     * @param message the delivery to process
     * @return a Mono that always completes without error
     */
    private Mono<Void> handleMessage(AcknowledgableDelivery message) {
        return Mono.defer(() -> {
                    var context = traceContextOf(message);
                    return calculationEventHandler.handle(message)
                            .doOnSuccess(v -> message.ack())
                            .contextWrite(Context.of(TraceContext.class, context));
                })
                .onErrorResume(ex -> {
                    log.error("Failed to handle message {}", new String(message.getBody()));
                    log.error("Exception:", ex);
                    message.nack(false);
                    return Mono.empty();
                });
    }

    /**
     * Rebuilds a trace context from the {@code traceId} and {@code spanId} message headers.
     *
     * @throws IllegalArgumentException when the headers are absent
     */
    private TraceContext traceContextOf(AcknowledgableDelivery message) {
        var headers = message.getProperties().getHeaders();
        if (headers == null || headers.get("traceId") == null || headers.get("spanId") == null) {
            throw new IllegalArgumentException("Message is missing the traceId/spanId headers");
        }
        var traceId = headers.get("traceId").toString();
        var spanId = headers.get("spanId").toString();
        return tracer.traceContextBuilder().traceId(traceId).spanId(spanId).build();
    }
}

This event is then processed:

@Component
@AllArgsConstructor
public class CalculationEventHandler  {

    private final ObjectMapper objectMapper;
    private final CalculationErrorHandler calculationErrorHandler;
    private final CalculationEventProcessor calculationEventProcessor;
    private final EventMessageReceiverProperties eventMessageReceiverProperties;

    public Mono<Void> handle(AcknowledgableDelivery message) {
        try {
            var event = objectMapper.readValue(message.getBody(), CalculationEvent.class);
            return calculationEventProcessor.process(event)
                    .onErrorResume(t -> calculationErrorHandler.handleError(message, t, event))
                    .then(Mono.empty());
        } catch (IOException e) {
            return Mono.error(e);
        }
    }
}

Processing includes validating the payload using some external APIs which are called in multiple ValidationService implementations. I'm using Flux.mergeDelayError to wait for all the responses from API services to assemble all errors if multiple API calls failed:

/**
 * Runs a calculation input through every registered {@link ValidationService}.
 */
@Service
@AllArgsConstructor
public class CalculationValidator {

    private final List<ValidationService> validationServices;

    /**
     * Validates the input with all services, merging their results and delaying errors
     * until every validation has terminated.
     *
     * @param input the calculation input to validate
     * @return a Mono that completes when all validations succeeded and fails otherwise
     */
    public Mono<Void> validate(CalculationInput input) {
        return Flux.mergeDelayError(Queues.XS_BUFFER_SIZE, validateWithEachService(input).toArray(Publisher[]::new)).then();
    }

    /**
     * Creates one validation Mono per registered service.
     * The name matches the call in {@link #validate(CalculationInput)}, which previously
     * referenced a helper that did not exist.
     */
    private List<Mono<Void>> validateWithEachService(CalculationInput input) {
        return validationServices.stream()
                .map(validationService -> validationService.validate(input))
                .toList();
    }
}

Here's what most of the ValidationService implementations look like:

/**
 * Validation step that posts the mapped calculation input to the threshold-validator API.
 *
 * <p>NOTE(review): the validator that consumes these services injects {@code ValidationService}
 * and calls {@code validate(input)}; this class declares {@code ValidatorService} — confirm that
 * both names refer to the same interface.
 */
@Service
public class ThresholdValidatorServiceImpl implements ValidatorService {

    private final WebClient webClient;
    private final String endpoint;
    private final ThresholdInputMapper mapper;

    /**
     * @param builder            WebClient builder used to create the client for this integration
     * @param errorFilterFactory factory for the error filter registered on the client
     * @param mapper             maps the calculation input to the request body
     * @param gateway            base URL of the threshold-validator API
     * @param endpoint           path of the validation endpoint
     */
    public ThresholdValidatorServiceImpl(WebClient.Builder builder,
                                        ValidationServiceErrorFilterFactory errorFilterFactory,
                                        ThresholdInputMapper mapper,
                                        @Value("${integration.threshold-validator.url}") String gateway,
                                        @Value("${integration.threshold-validator.endpoint}") String endpoint) {
        this.webClient = builder
                // The constructor parameter is named gateway; dpGateway was undefined.
                .baseUrl(gateway)
                .filter(errorFilterFactory.createFilterFor("threshold-validator"))
                .build();
        this.endpoint = endpoint;
        this.mapper = mapper;
    }

    /**
     * Posts the mapped input as JSON and discards the response body.
     *
     * @param input the calculation input to validate
     * @return a Mono that completes when the call succeeded; mapping and HTTP failures
     *         are emitted as error signals
     */
    @Override
    public Mono<Void> validate(CalculationInput input) {
        // defer keeps a mapper failure inside the reactive chain, so callers that merge
        // several validations still see it as an error signal.
        return Mono.defer(() -> Mono.just(mapper.map(input)))
                // flatMap, not map: the lambda itself returns a Mono.
                .flatMap(body -> webClient.post()
                        .uri(endpoint)
                        .contentType(MediaType.APPLICATION_JSON)
                        .bodyValue(body)
                        .retrieve()
                        .bodyToMono(Void.class));
    }
}

What I am after: I want to re-use the same traceId from the initial REST request by putting the traceId into the AMQP message properties and after consuming it via the Rabbit Receiver - inject the traceId into the Reactor Context so the API calls in each ValidationService via WebClient use the same traceId from the initial request.

Problem: I am able to save the traceId from the initial request and put it inside the AMQP message properties. After consuming the RabbitMQ message, I fetch the traceId and write it to the Reactor Context as a TraceContext. My understanding is that this context will be used upstream in the reactive pipeline, where the API calls are made. After doing multiple API calls, the WebClient seems to generate a new traceId for each .exchange(), which is not the behaviour I am expecting.

Question: Is this even possible to achieve? If yes, what would be the correct approach?

Dependencies used:

  • io.micrometer:micrometer-tracing:1.1.2
  • io.micrometer:context-propagation:1.0.3
  • io.projectreactor.rabbitmq:reactor-rabbitmq:1.5.6
  • org.springframework.boot:spring-boot-starter-webflux:3.1.0

EDIT: Added some code for more clarity, updated descriptions.

Airidas36
  • 21
  • 4
  • your first question yes, your second question, what is "correct" is subjective and we are not a code giving service. You have not provided a single line of code so we have no idea what your solution looks like so it is impossible for us to know what is the "correct" solution for you. Good luck – Toerktumlare Jul 07 '23 at 23:50
  • @Toerktumlare I updated the question with some code for context – Airidas36 Jul 10 '23 at 06:29

2 Answers2

1

Spring Boot 3.x uses Micrometer Tracing. All Spring Boot default configurations work with the Micrometer Observation API. The WebFlux WebClient expects an Observation object under the key "micrometer.observation" in the current context; if none is set, it starts a new observation, and hence the following code does not work:

contextWrite(Context.of(TraceContext.class, context));

Possible solutions to this problem

  1. Use Spring Cloud binder for RabbitMQ. It supports reactive specification including receiving messages from MQ as Flux. There is zero code required to make all your scenarios work. For example, Spring WebFlux will read trace/span from request header and add it to request context. It will add this to outgoing remote calls including web calls and messages (rabbitMQ). It will read these headers on incoming messages and create span from same for any outgoing calls. All your scenarios will work out of box. You can debug this in logs if using MDC (logback etc.) as it adds trace information in MDC context and takes care of most of scenarios of context switching.

  2. Instead of setting TraceContext in context, write specific ids like trace/span/parent etc directly in context and then add them to right header values when making downstream calls. Do not worry about context switches or which thread will execute webclient calls. If you simply set Hooks.enableAutomaticContextPropagation(); in main class then with minimal overhead your context should be passed.

  3. This is a bit complex, but I could not think of any easier solution. Create a custom io.micrometer.observation.transport.ReceiverContext and create a new Observation with it. You will not have to worry about reading spans from the headers, as the default propagator will extract them from the message headers. This takes inspiration from what the Spring RabbitMQ message listener does.

Create new ReceiverContext class similar to org.springframework.amqp.rabbit.support.micrometer.RabbitMessageReceiverContext

---
//Create new observation in receiver
Observation observation = Observation.createNotStarted(...); // Resolve ObservationRegistry Bean and pass custom receiver context
//Simply add code between 
observation.observe(()->//your code) // this will be added to ThreadLocal and Context

// PropagatingReceiverTracingObservationHandler already added to registry handler will handle this observation and extract the required header from message
  • Thanks, I managed to get it working by following your 3rd solution. I also had to add `.contextWrite(Context.of(ObservationThreadLocalAccessor.KEY, observation))` at the end of the `handleMessage()` method in the `EventMessageListener` class – Airidas36 Jul 12 '23 at 07:49
0

I managed to get it working by following @Gaurav advice and creating my own custom ReceiverContext which looks like this:

/**
 * Receiver-side observation context for a RabbitMQ delivery.
 *
 * <p>Propagation fields are read from the AMQP message headers of the carrier.
 */
public class RabbitReceiverContext extends ReceiverContext<AcknowledgableDelivery> {

    private final String queue;
    private final String listenerId;

    /**
     * @param message    the delivery used as the propagation carrier
     * @param queueName  name of the queue the delivery was consumed from
     * @param listenerId identifier of the consuming listener
     */
    public RabbitReceiverContext(AcknowledgableDelivery message, String queueName, String listenerId) {
        super(RabbitReceiverContext::headerValue);
        setCarrier(message);
        this.listenerId = listenerId;
        this.queue = queueName;
        setRemoteServiceName("RabbitMQ");
    }

    /**
     * Reads a header as a string.
     *
     * @return the header value, or {@code null} when the message has no headers or the key is
     *         absent, so the propagator sees a missing field rather than an empty one
     */
    private static String headerValue(AcknowledgableDelivery carrier, String key) {
        var headers = carrier.getProperties().getHeaders();
        if (headers == null) {
            return null;
        }
        var value = headers.get(key);
        return value == null ? null : value.toString();
    }

    /** @return the identifier of the consuming listener */
    public String getListenerId() {
        return this.listenerId;
    }

    /** @return the name of the queue the delivery was consumed from */
    public String getSource() {
        return this.queue;
    }
}

Then, I am able to pass it on to the Observation that I am creating in the EventMessageListener handleMessage() method, note - it did not work out of the box, I also had to add .contextWrite(Context.of(ObservationThreadLocalAccessor.KEY, observation)) for it to work:

/**
 * Dispatches one delivery inside a receiver observation, acking on success and
 * logging and nacking (without requeue) on failure.
 *
 * <p>The observation is started when the pipeline is subscribed and stopped when it
 * terminates or is cancelled. Wrapping the pipeline in {@code observation.observe(...)}
 * would only cover its assembly, because the supplier returns before any work has run.
 *
 * @param message the delivery to process
 * @return a Mono that always completes without error
 */
private Mono<Void> handleMessage(AcknowledgableDelivery message) {
    var observation = Observation.createNotStarted(
            "calculation",
            () -> new RabbitReceiverContext(message, eventMessageReceiverProperties.getQueue(), eventReceiver.toString()),
            observationRegistry
    );
    return Mono.defer(() -> eventMessageDispatcher.dispatch(message))
            .doOnSuccess(v -> message.ack())
            .onErrorResume(ex -> {
                // Record the failure on the observation before it is swallowed below.
                observation.error(ex);
                log.error("Failed to handle message {}", new String(message.getBody()));
                log.error("Exception:", ex);
                message.nack(false);
                return Mono.empty();
            })
            // doFirst runs at subscription, before the dispatch pipeline is subscribed.
            .doFirst(observation::start)
            .doFinally(signal -> observation.stop())
            // Exposes the observation to downstream instrumentation through the Reactor Context.
            .contextWrite(Context.of(ObservationThreadLocalAccessor.KEY, observation));
}

One more thing to note - I also had to adjust the message publishing to construct a traceparent header, because Micrometer uses W3C context propagation by default. Header construction is done like this in the CalculationInitializer:

/**
 * Builds the AMQP properties for an outbound event, including a W3C {@code traceparent}
 * header derived from the current trace context.
 *
 * <p>When no trace context is current, only the event-type header is set.
 *
 * @param eventConfig configuration supplying the event-type header value
 * @return the properties to attach to the outbound message
 */
private AMQP.BasicProperties messageProperties(EventPublisherProperties.EventConfig eventConfig) {
    var context = tracer.currentTraceContext().context();
    if (context == null) {
        return new AMQP.BasicProperties.Builder()
                .headers(Map.of(EVENT_TYPE_HEADER, eventConfig.headerValue()))
                .build();
    }
    // trace-flags: 01 = sampled, 00 = not sampled. Hard-coding 00 would tell the consumer
    // that the trace is unsampled even when it is being recorded.
    var traceFlags = Boolean.TRUE.equals(context.sampled()) ? "01" : "00";
    var traceParent = String.format("00-%s-%s-%s", context.traceId(), context.spanId(), traceFlags);
    return new AMQP.BasicProperties.Builder()
            .headers(Map.of(EVENT_TYPE_HEADER, eventConfig.headerValue(), "traceparent", traceParent))
            .build();
}

After these changes were implemented every WebClient exchange() is automatically propagating the traceId from the initial request.

Airidas36
  • 21
  • 4