Search code examples
Tags: java, slf4j, fork-join, mdc

How to use MDC with ForkJoinPool?


Following up on "How to use MDC with thread pools?", how can one use MDC with a ForkJoinPool? Specifically, how can one wrap a ForkJoinTask so that MDC values are set before the task executes?


Solution

  • The following seems to work for me:

    import java.lang.Thread.UncaughtExceptionHandler;
    import java.util.ArrayList;
    import java.util.Collection;
    import java.util.List;
    import java.util.Map;
    import java.util.Objects;
    import java.util.concurrent.Callable;
    import java.util.concurrent.ForkJoinPool;
    import java.util.concurrent.ForkJoinTask;
    import java.util.concurrent.Future;
    import java.util.concurrent.atomic.AtomicReference;
    import org.slf4j.MDC;
    
    /**
     * A {@link ForkJoinPool} that inherits MDC contexts from the thread that queues a task.
     *
     * @author Gili Tzabari
     */
    public final class MdcForkJoinPool extends ForkJoinPool
    {
        /**
         * Creates a new MdcForkJoinPool.
         *
         * @param parallelism the parallelism level. For default value, use {@link java.lang.Runtime#availableProcessors}.
         * @param factory     the factory for creating new threads. For default value, use
         *                    {@link #defaultForkJoinWorkerThreadFactory}.
         * @param handler     the handler for internal worker threads that terminate due to unrecoverable errors encountered
         *                    while executing tasks. For default value, use {@code null}.
         * @param asyncMode   if true, establishes local first-in-first-out scheduling mode for forked tasks that are never
         *                    joined. This mode may be more appropriate than default locally stack-based mode in applications
         *                    in which worker threads only process event-style asynchronous tasks. For default value, use
         *                    {@code false}.
         * @throws IllegalArgumentException if parallelism less than or equal to zero, or greater than implementation limit
         * @throws NullPointerException     if the factory is null
         * @throws SecurityException        if a security manager exists and the caller is not permitted to modify threads
         *                                  because it does not hold
         *                                  {@link java.lang.RuntimePermission}{@code ("modifyThread")}
         */
        public MdcForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory, UncaughtExceptionHandler handler,
            boolean asyncMode)
        {
            super(parallelism, factory, handler, asyncMode);
        }

        /**
         * Arranges for asynchronous execution of the task with the caller's MDC context installed.
         *
         * @throws NullPointerException if the task is null
         */
        @Override
        public void execute(ForkJoinTask<?> task)
        {
            // See http://stackoverflow.com/a/19329668/14731
            super.execute(wrap(task, MDC.getCopyOfContextMap()));
        }

        /**
         * Executes the given command at some time in the future with the caller's MDC context installed.
         *
         * @throws NullPointerException if the task is null
         */
        @Override
        public void execute(Runnable task)
        {
            // See http://stackoverflow.com/a/19329668/14731
            super.execute(wrap(task, MDC.getCopyOfContextMap()));
        }

        /**
         * Submits a task for execution with the caller's MDC context installed.
         * <p>
         * The returned handle is a wrapper around {@code task}. Joining either the wrapper or {@code task} yields the
         * task's result.
         *
         * @throws NullPointerException if the task is null
         */
        @Override
        public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task)
        {
            return super.submit(wrap(task, MDC.getCopyOfContextMap()));
        }

        /**
         * Submits a value-returning task for execution with the caller's MDC context installed.
         *
         * @throws NullPointerException if the task is null
         */
        @Override
        public <T> ForkJoinTask<T> submit(Callable<T> task)
        {
            return super.submit(wrap(task, MDC.getCopyOfContextMap()));
        }

        /**
         * Submits a Runnable for execution with the caller's MDC context installed.
         *
         * @throws NullPointerException if the task is null
         */
        @Override
        public <T> ForkJoinTask<T> submit(Runnable task, T result)
        {
            return super.submit(wrap(task, MDC.getCopyOfContextMap()), result);
        }

        /**
         * Submits a Runnable for execution with the caller's MDC context installed.
         *
         * @throws NullPointerException if the task is null
         */
        @Override
        public ForkJoinTask<?> submit(Runnable task)
        {
            return super.submit(wrap(task, MDC.getCopyOfContextMap()));
        }

        /**
         * Performs the given task with the caller's MDC context installed and returns its result upon completion.
         *
         * @throws NullPointerException if the task is null
         */
        @Override
        public <T> T invoke(ForkJoinTask<T> task)
        {
            return super.invoke(wrap(task, MDC.getCopyOfContextMap()));
        }

        /**
         * Executes the given tasks, each with the caller's MDC context installed.
         *
         * @throws NullPointerException if {@code tasks} or any of its elements are null
         */
        @Override
        public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks)
        {
            // One snapshot is shared by all tasks; beforeExecution only hands it to MDC.setContextMap and never mutates it.
            Map<String, String> context = MDC.getCopyOfContextMap();
            List<Callable<T>> wrapped = new ArrayList<>(tasks.size());
            for (Callable<T> task : tasks)
                wrapped.add(wrap(task, context));
            return super.invokeAll(wrapped);
        }

        /**
         * Wraps a ForkJoinTask so that it runs with {@code newContext} installed.
         *
         * @param task       the task to wrap
         * @param newContext the MDC context to install ({@code null} clears the MDC)
         * @return a task that invokes {@code task} and exposes its result
         * @throws NullPointerException if {@code task} is null (reported at submission, not later on a worker thread)
         */
        private static <T> ForkJoinTask<T> wrap(ForkJoinTask<T> task, Map<String, String> newContext)
        {
            Objects.requireNonNull(task, "task");
            return new ForkJoinTask<T>()
            {
                private static final long serialVersionUID = 1L;
                /**
                 * If non-null, overrides the value returned by the underlying task.
                 */
                private final AtomicReference<T> override = new AtomicReference<>();

                @Override
                public T getRawResult()
                {
                    T result = override.get();
                    if (result != null)
                        return result;
                    return task.getRawResult();
                }

                @Override
                protected void setRawResult(T value)
                {
                    override.set(value);
                }

                @Override
                protected boolean exec()
                {
                    // According to ForkJoinTask.fork() "it is a usage error to fork a task more than once unless it has completed
                    // and been reinitialized". We therefore assume that this method does not have to be thread-safe.
                    Map<String, String> oldContext = beforeExecution(newContext);
                    try
                    {
                        task.invoke();
                        return true;
                    }
                    finally
                    {
                        afterExecution(oldContext);
                    }
                }
            };
        }

        /**
         * Wraps a Runnable so that it runs with {@code newContext} installed.
         *
         * @throws NullPointerException if {@code task} is null
         */
        private static Runnable wrap(Runnable task, Map<String, String> newContext)
        {
            Objects.requireNonNull(task, "task");
            return () ->
            {
                Map<String, String> oldContext = beforeExecution(newContext);
                try
                {
                    task.run();
                }
                finally
                {
                    afterExecution(oldContext);
                }
            };
        }

        /**
         * Wraps a Callable so that it runs with {@code newContext} installed.
         *
         * @throws NullPointerException if {@code task} is null
         */
        private static <T> Callable<T> wrap(Callable<T> task, Map<String, String> newContext)
        {
            Objects.requireNonNull(task, "task");
            return () ->
            {
                Map<String, String> oldContext = beforeExecution(newContext);
                try
                {
                    return task.call();
                }
                finally
                {
                    afterExecution(oldContext);
                }
            };
        }

        /**
         * Invoked before running a task.
         *
         * @param newValue the new MDC context ({@code null} clears the MDC)
         * @return the worker thread's previous MDC context
         */
        private static Map<String, String> beforeExecution(Map<String, String> newValue)
        {
            Map<String, String> previous = MDC.getCopyOfContextMap();
            if (newValue == null)
                MDC.clear();
            else
                MDC.setContextMap(newValue);
            return previous;
        }

        /**
         * Invoked after running a task; restores the worker thread's previous MDC context.
         *
         * @param oldValue the old MDC context ({@code null} clears the MDC)
         */
        private static void afterExecution(Map<String, String> oldValue)
        {
            if (oldValue == null)
                MDC.clear();
            else
                MDC.setContextMap(oldValue);
        }
    }
    

    and

    import java.util.Map;
    import java.util.concurrent.CountedCompleter;
    import org.slf4j.MDC;
    
    /**
     * A {@link CountedCompleter} that inherits MDC contexts from the thread that queues a task.
     *
     * @author Gili Tzabari
     * @param <T> The result type returned by this task's {@code get} method
     */
    public abstract class MdcCountedCompleter<T> extends CountedCompleter<T>
    {
        private static final long serialVersionUID = 1L;
        /**
         * The MDC context captured from the constructing thread; {@code null} if that thread had no context.
         */
        private final Map<String, String> newContext;

        /**
         * Creates a new MdcCountedCompleter instance with no completer, an initial pending count of zero, and the MDC
         * context of the current thread.
         */
        protected MdcCountedCompleter()
        {
            this(null, 0);
        }

        /**
         * Creates a new MdcCountedCompleter instance with an initial pending count of zero, using the MDC context of
         * the current thread.
         *
         * @param completer this task's completer; {@code null} if none
         */
        protected MdcCountedCompleter(CountedCompleter<?> completer)
        {
            this(completer, 0);
        }

        /**
         * Creates a new MdcCountedCompleter instance using the MDC context of the current thread.
         *
         * @param completer           this task's completer; {@code null} if none
         * @param initialPendingCount the initial pending count
         */
        protected MdcCountedCompleter(CountedCompleter<?> completer, int initialPendingCount)
        {
            super(completer, initialPendingCount);
            // Captured here, not at execution time: forked subtasks never pass through a pool's execute(), so the
            // constructor is the only point where the parent's context is reliably available.
            this.newContext = MDC.getCopyOfContextMap();
        }

        /**
         * The main computation performed by this task. It runs with the MDC context captured at construction time
         * installed.
         */
        protected abstract void computeWithContext();

        /**
         * Installs the captured MDC context, runs {@link #computeWithContext()}, then restores the executing thread's
         * previous context, even if the computation throws.
         */
        @Override
        public final void compute()
        {
            Map<String, String> oldContext = beforeExecution(newContext);
            try
            {
                computeWithContext();
            }
            finally
            {
                afterExecution(oldContext);
            }
        }

        /**
         * Invoked before running a task.
         *
         * @param newValue the new MDC context ({@code null} clears the MDC)
         * @return the old MDC context
         */
        private static Map<String, String> beforeExecution(Map<String, String> newValue)
        {
            Map<String, String> previous = MDC.getCopyOfContextMap();
            if (newValue == null)
                MDC.clear();
            else
                MDC.setContextMap(newValue);
            return previous;
        }

        /**
         * Invoked after running a task.
         *
         * @param oldValue the old MDC context ({@code null} clears the MDC)
         */
        private static void afterExecution(Map<String, String> oldValue)
        {
            if (oldValue == null)
                MDC.clear();
            else
                MDC.setContextMap(oldValue);
        }
    }
    
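    1. Run your tasks against MdcForkJoinPool instead of the common ForkJoinPool.
    2. Extend MdcCountedCompleter instead of CountedCompleter. Subtasks forked from inside a worker thread are pushed directly onto that worker's queue without passing through the pool's execute() method, so MdcForkJoinPool cannot wrap them. MdcCountedCompleter covers this case by capturing the MDC context when the subtask is constructed.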