Tags: java, spring, spring-integration, amqp, spring-amqp

Manual ACK for AggregatingMessageHandler


I'm trying to build an integration scenario like this: Rabbit -> AmqpInboundChannelAdapter(AcknowledgeMode.MANUAL) -> DirectChannel -> AggregatingMessageHandler -> DirectChannel -> AmqpOutboundEndpoint.

I want to aggregate messages in memory and release the group once 10 messages have been aggregated, or when a timeout of 10 seconds is reached. I suppose this config is OK:

@Bean
@ServiceActivator(inputChannel = "amqpInputChannel")
public MessageHandler aggregator(){
    AggregatingMessageHandler aggregatingMessageHandler = new AggregatingMessageHandler(new DefaultAggregatingMessageGroupProcessor(), new SimpleMessageStore(10));
    aggregatingMessageHandler.setCorrelationStrategy(new HeaderAttributeCorrelationStrategy(AmqpHeaders.CORRELATION_ID));
    // default is false; when a group is released by the release strategy, remove it so new messages with the same correlation ID start a new group
    aggregatingMessageHandler.setExpireGroupsUponCompletion(true);
    // when a group expires because of the timeout (not because of the release strategy), still send the messages grouped so far
    aggregatingMessageHandler.setSendPartialResultOnExpiry(true);
    // force-complete the group after a 10-second timeout
    aggregatingMessageHandler.setGroupTimeoutExpression(new ValueExpression<>(TimeUnit.SECONDS.toMillis(10)));

    // note: this strategy checks its timeout only when a new message arrives
    aggregatingMessageHandler.setReleaseStrategy(new TimeoutCountSequenceSizeReleaseStrategy(10, TimeUnit.SECONDS.toMillis(10)));
    aggregatingMessageHandler.setOutputChannel(amqpOutputChannel());
    return aggregatingMessageHandler;
}
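To make the intended release rule concrete (a full batch goes out at 10 messages, or a partial batch goes out once 10 seconds have passed), here is a plain-Java model of it with no Spring dependencies. `CountOrTimeoutBuffer` is a hypothetical name, time is passed in explicitly instead of using a scheduler, and the timeout is measured from the first message of the batch; this does not reproduce Spring Integration's own group-timeout scheduling, which may measure the timeout differently.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Hypothetical plain-Java model of the release rule: emit a batch when it
 * reaches maxSize, or emit a partial batch once timeoutMillis have passed
 * since the first message of the batch arrived. Time is passed in explicitly
 * so the behavior is deterministic; this is not Spring Integration's code.
 */
final class CountOrTimeoutBuffer<T> {
    private final int maxSize;
    private final long timeoutMillis;
    private final List<T> current = new ArrayList<>();
    private long batchStartMillis;

    CountOrTimeoutBuffer(int maxSize, long timeoutMillis) {
        this.maxSize = maxSize;
        this.timeoutMillis = timeoutMillis;
    }

    /** Adds a message; returns the full batch if this message completed it, otherwise an empty list. */
    List<T> add(T message, long nowMillis) {
        if (current.isEmpty()) {
            batchStartMillis = nowMillis; // the first message opens a new batch
        }
        current.add(message);
        return current.size() >= maxSize ? release() : Collections.<T>emptyList();
    }

    /** Meant to be called periodically; returns a partial batch if the open batch has timed out. */
    List<T> poll(long nowMillis) {
        if (!current.isEmpty() && nowMillis - batchStartMillis >= timeoutMillis) {
            return release(); // comparable in spirit to sendPartialResultOnExpiry = true
        }
        return Collections.emptyList();
    }

    /** Hands out the buffered messages and starts over with an empty batch. */
    private List<T> release() {
        List<T> batch = new ArrayList<>(current);
        current.clear();
        return batch;
    }
}
```

Calling `poll` periodically plays the role the group timeout plays in the aggregator; without it, a half-full batch would only be checked when the next message arrives, which is the limitation the comment on `TimeoutCountSequenceSizeReleaseStrategy` points out.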

Now, my question is: is there any easier way to manually ack messages than creating my own implementation of AggregatingMessageHandler, like this:

public class ManualAckAggregatingMessageHandler extends AbstractCorrelatingMessageHandler {
    ...

    private void ackMessage(Channel channel, Long deliveryTag){
        try {
            Assert.notNull(channel, "Channel must be provided");
            Assert.notNull(deliveryTag, "Delivery tag must be provided");
            channel.basicAck(deliveryTag, false);
        }
        catch (IOException e) {
            throw new MessagingException("Cannot ACK message", e);
        }
    }

    @Override
    protected void afterRelease(MessageGroup messageGroup, Collection<Message<?>> completedMessages) {
        Object groupId = messageGroup.getGroupId();
        MessageGroupStore messageStore = getMessageStore();
        messageStore.completeGroup(groupId);

        messageGroup.getMessages().forEach(m -> {
            Channel channel = (Channel)m.getHeaders().get(AmqpHeaders.CHANNEL);
            Long deliveryTag = (Long)m.getHeaders().get(AmqpHeaders.DELIVERY_TAG);
            ackMessage(channel, deliveryTag);
        });

        if (this.expireGroupsUponCompletion) {
            remove(messageGroup);
        }
        else {
            if (messageStore instanceof SimpleMessageStore) {
                ((SimpleMessageStore) messageStore).clearMessageGroup(groupId);
            }
            else {
                messageStore.removeMessagesFromGroup(groupId, messageGroup.getMessages());
            }
        }
    }
}

UPDATE

I managed to do it with your help. The most important parts: the connection factory must have factory.setPublisherConfirms(true), and the AmqpOutboundEndpoint must have these two settings: outboundEndpoint.setConfirmAckChannel(manualAckChannel()) and outboundEndpoint.setConfirmCorrelationExpressionString("#root"). Here is the implementation of the rest of the classes:

public class ManualAckPair {
    private Channel channel;
    private Long deliveryTag;

    public ManualAckPair(Channel channel, Long deliveryTag) {
        this.channel = channel;
        this.deliveryTag = deliveryTag;
    }

    public void basicAck(){
        try {
            this.channel.basicAck(this.deliveryTag, false);
        }
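        catch (IOException e) {
            // rethrow instead of only printing, so a failed ack is not silently ignored
            throw new MessagingException("Cannot ACK message", e);
        }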
    }
}

public abstract class AbstractManualAckAggregatingMessageGroupProcessor extends AbstractAggregatingMessageGroupProcessor {
    public static final String MANUAL_ACK_PAIRS = PREFIX + "manualAckPairs";

    @Override
    protected Map<String, Object> aggregateHeaders(MessageGroup group) {
        Map<String, Object> aggregatedHeaders = super.aggregateHeaders(group);
        List<ManualAckPair> manualAckPairs = new ArrayList<>();
        group.getMessages().forEach(m -> {
            Channel channel = (Channel)m.getHeaders().get(AmqpHeaders.CHANNEL);
            Long deliveryTag = (Long)m.getHeaders().get(AmqpHeaders.DELIVERY_TAG);
            manualAckPairs.add(new ManualAckPair(channel, deliveryTag));
        });
        aggregatedHeaders.put(MANUAL_ACK_PAIRS, manualAckPairs);
        return aggregatedHeaders;
    }
}

and

@Service
public class ManualAckServiceActivator {

    @ServiceActivator(inputChannel = "manualAckChannel")
    public void handle(@Header(MANUAL_ACK_PAIRS) List<ManualAckPair> manualAckPairs) {
        manualAckPairs.forEach(ManualAckPair::basicAck);
    }
}
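The key ordering in this update is that the inbound messages are acked only after the broker has confirmed the outbound publish (the confirm is routed to manualAckChannel, where the service activator runs). Below is a dependency-free sketch of that flow: `AckingChannel` is a hypothetical one-method stand-in for the RabbitMQ `Channel`, `RecordingChannel` is a test double, and `AckPair` / `ConfirmAckHandler` are hypothetical names playing the roles of `ManualAckPair` and `ManualAckServiceActivator`. As in the code above, each pair is acked individually with `multiple = false`.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-in for the one method of the RabbitMQ Channel used here. */
interface AckingChannel {
    void basicAck(long deliveryTag, boolean multiple) throws IOException;
}

/** Test double that records every ack it receives as "tag:multiple". */
final class RecordingChannel implements AckingChannel {
    final List<String> acks = new ArrayList<>();

    @Override
    public void basicAck(long deliveryTag, boolean multiple) {
        acks.add(deliveryTag + ":" + multiple);
    }
}

/** Plays the role of ManualAckPair: remembers which channel delivered which tag. */
final class AckPair {
    private final AckingChannel channel;
    private final long deliveryTag;

    AckPair(AckingChannel channel, long deliveryTag) {
        this.channel = channel;
        this.deliveryTag = deliveryTag;
    }

    void basicAck() {
        try {
            channel.basicAck(deliveryTag, false); // ack just this one delivery
        } catch (IOException e) {
            throw new UncheckedIOException("Cannot ACK message", e);
        }
    }
}

/** Plays the role of ManualAckServiceActivator: invoked only once the publish is confirmed. */
final class ConfirmAckHandler {
    static void onPublishConfirmed(List<AckPair> pairsFromHeader) {
        pairsFromHeader.forEach(AckPair::basicAck);
    }
}
```

Nothing is acked while the aggregated message is still in flight, so if the application stops before the confirm arrives, the broker can redeliver the unacked inputs.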

Solution

  • Right, you don't need such complex logic in the aggregator.

    You can simply acknowledge them after the aggregator releases the group, in a service activator between the aggregator and that AmqpOutboundEndpoint.

    And, right, you have to use basicAck() there with the multiple flag set to true:

    @param multiple true to acknowledge all messages up to and including the supplied delivery tag; false to acknowledge just the supplied delivery tag.
    

    Well, for that purpose you definitely need a custom MessageGroupProcessor to extract the highest AmqpHeaders.DELIVERY_TAG for the whole batch and set it as a header on the aggregated output message.
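A small model of that extraction step, assuming the batch is represented as a list of hypothetical `Delivery` entries (a channel id standing in for the `AmqpHeaders.CHANNEL` header value, plus the delivery tag). Delivery tags are scoped to the channel that delivered them, so the sketch keeps the highest tag per channel and issues one `basicAck(tag, true)` per channel through a hypothetical `Acker` callback. Keep in mind that `multiple = true` also acknowledges any other unacked deliveries with lower tags on that channel, including ones that belong to a group that has not been released yet.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical batch entry: channel id (stand-in for AmqpHeaders.CHANNEL) and delivery tag. */
final class Delivery {
    final String channelId;
    final long deliveryTag;

    Delivery(String channelId, long deliveryTag) {
        this.channelId = channelId;
        this.deliveryTag = deliveryTag;
    }
}

final class BatchAcks {
    /** Hypothetical callback standing in for Channel.basicAck(deliveryTag, multiple). */
    interface Acker {
        void basicAck(String channelId, long deliveryTag, boolean multiple);
    }

    /** Highest delivery tag seen on each channel, in first-seen channel order. */
    static Map<String, Long> highestTagPerChannel(List<Delivery> batch) {
        Map<String, Long> highest = new LinkedHashMap<>();
        for (Delivery d : batch) {
            highest.merge(d.channelId, d.deliveryTag, Math::max);
        }
        return highest;
    }

    /** One multiple=true ack per channel instead of one ack per message. */
    static void ackBatch(List<Delivery> batch, Acker acker) {
        highestTagPerChannel(batch).forEach((channelId, tag) -> acker.basicAck(channelId, tag, true));
    }
}
```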

    You might just extend DefaultAggregatingMessageGroupProcessor and override its aggregateHeaders():

    /**
     * This default implementation simply returns all headers that have no conflicts among the group. An absent header
     * on one or more Messages within the group is not considered a conflict. Subclasses may override this method with
     * more advanced conflict-resolution strategies if necessary.
     *
     * @param group The message group.
     * @return The aggregated headers.
     */
    protected Map<String, Object> aggregateHeaders(MessageGroup group) {