
Append a correlation ID to every log line in Java


I need to append a correlation ID to the log lines of every request I receive.

Here is my filter. It works well until execution reaches an async block such as CompletableFuture.runAsync().


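I read about MDC and how it uses a ThreadLocal, but I can't work out how to use it in async code: runAsync() without an explicit executor normally runs the task on ForkJoinPool.commonPool(), and those worker threads don't see the request thread's ThreadLocal values.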
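A minimal JDK-only sketch of the problem, assuming (as described above) that the MDC is backed by a ThreadLocal. A plain ThreadLocal named CORRELATION_ID stands in for the MDC, and the class name is hypothetical. A dedicated single-thread executor is used only to make the demo deterministic; the point is the same for the common pool: the task runs on a different thread, which has its own, empty ThreadLocal slot.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class ThreadLocalLossDemo {

    // Hypothetical stand-in for the MDC: a plain ThreadLocal, which is what Logback's MDC adapter builds on
    static final ThreadLocal<String> CORRELATION_ID = new ThreadLocal<>();

    /** Sets the value on the calling thread, then returns what an async task on another thread sees. */
    static String valueSeenByAsyncTask() {
        // A dedicated pool keeps the demo deterministic; runAsync(Runnable) without an
        // executor normally hands the task to a worker of the common ForkJoinPool instead.
        ExecutorService pool = Executors.newSingleThreadExecutor();
        CORRELATION_ID.set("abc-123");          // e.g. what the servlet filter does
        try {
            return CompletableFuture.supplyAsync(CORRELATION_ID::get, pool).join();
        } finally {
            CORRELATION_ID.remove();
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println("async task saw: " + valueSeenByAsyncTask()); // prints "async task saw: null"
    }
}
```

This is exactly why `%X{correlationId}` comes out empty for log lines written from inside runAsync().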

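```java
import static java.util.Optional.ofNullable;

import java.io.IOException;
import javax.servlet.FilterChain;              // jakarta.servlet.* on Spring Boot 3 / Spring 6
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.slf4j.MDC;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;

@Component
public class Slf4jFilter extends OncePerRequestFilter {

    private static final String CORRELATION_ID_HEADER_NAME = "correlation-id";
    private static final String CORRELATION_ID_LOG_VAR_NAME = "correlationId";

    @Override
    protected void doFilterInternal(HttpServletRequest request,
                                    HttpServletResponse response,
                                    FilterChain chain) throws ServletException, IOException {
        try {
            // Copy the incoming header (if present) into the MDC of the request thread
            ofNullable(request.getHeader(CORRELATION_ID_HEADER_NAME))
                    .ifPresent(correlationId -> MDC.put(CORRELATION_ID_LOG_VAR_NAME, correlationId));
            chain.doFilter(request, response);
        } finally {
            // Request threads are pooled, so always clear the value afterwards
            removeCorrelationId();
        }
    }

    protected void removeCorrelationId() {
        MDC.remove(CORRELATION_ID_LOG_VAR_NAME);
    }
}
```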

logback.xml

```xml
<configuration>
    <appender name="stdout" class="ch.qos.logback.core.ConsoleAppender">
        <encoder class="ch.qos.logback.classic.encoder.PatternLayoutEncoder">
            <!-- %X{correlationId} prints the MDC value put there by the filter -->
            <pattern>%d{dd-MM-yyyy HH:mm:ss.SSS} [%thread %X{correlationId}] %-5level %logger{36} - %msg %n</pattern>
        </encoder>
    </appender>
    <root level="INFO">
        <appender-ref ref="stdout" />
    </root>
</configuration>
```

Solution

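Here's the solution I ended up with, thanks to @M.Prokhorov.

Create a class called MdcRetention in your project. It captures the MDC when wrap() is called (on the request thread) and installs that copy on the worker thread for the duration of the task: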

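```java
import java.util.Collections;
import java.util.Map;
import org.slf4j.MDC;

public final class MdcRetention {

    private MdcRetention() {
        // static utility class
    }

    /** Wraps a Runnable so that it runs with the MDC captured when wrap() was called. */
    public static Runnable wrap(final Runnable delegate) {
        return new MdcRetainingRunnable() {
            @Override
            protected void runInContext() {
                delegate.run();
            }
        };
    }

    private abstract static class MdcRetentionSupport {
        protected final Map<String, String> originalMdc;

        // Runs on the calling thread (e.g. the request thread), so it captures that thread's MDC
        protected MdcRetentionSupport() {
            Map<String, String> originalMdc = MDC.getCopyOfContextMap();
            this.originalMdc = originalMdc == null ? Collections.emptyMap() : originalMdc;
        }
    }

    public abstract static class MdcRetainingRunnable extends MdcRetentionSupport implements Runnable {

        @Override
        public final void run() {
            // Runs on the worker thread: remember its own MDC, then install the captured one
            Map<String, String> currentMdc = MDC.getCopyOfContextMap();
            MDC.setContextMap(originalMdc);
            try {
                runInContext();
            } finally {
                // Restore the worker's previous MDC; getCopyOfContextMap() may return null
                if (currentMdc == null) {
                    MDC.clear();
                } else {
                    MDC.setContextMap(currentMdc);
                }
            }
        }

        protected abstract void runInContext();
    }
}
```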

Then wrap the Runnable you pass to runAsync() with the static method MdcRetention.wrap() (statically imported here):

Before: `runAsync(() -> someMethod());`

After: `runAsync(wrap(() -> someMethod()));`
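The same capture-then-restore idea can also cover supplyAsync() and any executor you pass to the async methods. Below is a minimal JDK-only sketch of that extension. So that it runs without SLF4J, it uses a plain `ThreadLocal<Map<String, String>>` named CONTEXT as a hypothetical stand-in for the MDC. The names ContextPropagation, wrap(Supplier) and contextAware(Executor) are illustrative only and are not part of SLF4J or the JDK. With the real MDC you would use MDC.getCopyOfContextMap(), MDC.setContextMap() and MDC.clear() instead.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Supplier;

public final class ContextPropagation {

    // Hypothetical stand-in for the MDC: one String map per thread
    static final ThreadLocal<Map<String, String>> CONTEXT = ThreadLocal.withInitial(HashMap::new);

    private ContextPropagation() {
    }

    /** Runs body with the given context installed, then restores the thread's previous context. */
    private static <T> T callWith(Map<String, String> context, Supplier<T> body) {
        Map<String, String> previous = CONTEXT.get();   // the worker thread's own context
        CONTEXT.set(new HashMap<>(context));
        try {
            return body.get();
        } finally {
            CONTEXT.set(previous);                      // pooled threads must not keep the caller's values
        }
    }

    /** Supplier counterpart of MdcRetention.wrap(Runnable), for supplyAsync(). */
    static <T> Supplier<T> wrap(Supplier<T> delegate) {
        Map<String, String> captured = new HashMap<>(CONTEXT.get()); // runs on the calling thread
        return () -> callWith(captured, delegate);
    }

    /** Decorates an Executor so each task runs with the context of the thread that called execute(). */
    static Executor contextAware(Executor delegate) {
        return task -> {
            Map<String, String> captured = new HashMap<>(CONTEXT.get());
            delegate.execute(() -> callWith(captured, () -> {
                task.run();
                return null;
            }));
        };
    }

    public static void main(String[] args) {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        CONTEXT.get().put("correlationId", "abc-123");   // e.g. set by the servlet filter
        Supplier<String> readId = () -> CONTEXT.get().get("correlationId");

        System.out.println(CompletableFuture.supplyAsync(readId, pool).join());               // null
        System.out.println(CompletableFuture.supplyAsync(wrap(readId), pool).join());         // abc-123
        System.out.println(CompletableFuture.supplyAsync(readId, contextAware(pool)).join()); // abc-123
        pool.shutdown();
    }
}
```

Decorating the executor means call sites no longer have to remember to call wrap(). The trade-off is that the executor must be passed explicitly (for example `runAsync(task, executor)`), because the overloads without an executor argument always use the default async pool.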