Tags: java, multithreading, sorting, mergesort, fork-join

How to implement a multi-threaded MergeSort in Java


Most examples I find of Merge Sort run in a single thread. This forfeits some of the advantage of using Merge Sort in the first place. Can someone show the proper way to write a Merge Sort in Java using multi-threading?

The solution should use features from the latest version of Java where applicable. Many solutions already on Stack Overflow use plain Threads. I'm looking for a solution that demonstrates ForkJoin with a RecursiveTask, which seems to be the primary use case the RecursiveTask class was intended for.

Emphasis should be on showing an algorithm with superior performance characteristics including both time and space complexity where possible.

NOTE: Neither of the proposed duplicate questions applies, since neither provides a solution using RecursiveTask, which is specifically what this question asks for.


Solution

  • The most convenient multi-threading paradigm for a Merge Sort is the fork-join paradigm. The fork/join framework has been part of Java since Java 7, and ForkJoinPool.commonPool(), used below, since Java 8. The following code demonstrates a Merge Sort using fork-join.

    import java.util.*;
    import java.util.concurrent.*;
    
    public class MergeSort<N extends Comparable<N>> extends RecursiveTask<List<N>> {
        private List<N> elements;
    
        public MergeSort(List<N> elements) {
            this.elements = new ArrayList<>(elements);
        }
    
        @Override
        protected List<N> compute() {
            if(this.elements.size() <= 1)
                return this.elements;
            else {
                final int pivot = this.elements.size() / 2;
                MergeSort<N> leftTask = new MergeSort<N>(this.elements.subList(0, pivot));
                MergeSort<N> rightTask = new MergeSort<N>(this.elements.subList(pivot, this.elements.size()));
    
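                // Fork one half and sort the other in the current thread,
                // so this worker stays busy instead of blocking in join().
                leftTask.fork();
                List<N> right = rightTask.compute();
                List<N> left = leftTask.join();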
    
                return merge(left, right);
            }
        }
    
        private List<N> merge(List<N> left, List<N> right) {
            List<N> sorted = new ArrayList<>();
            while(!left.isEmpty() || !right.isEmpty()) {
                if(left.isEmpty())
                    sorted.add(right.remove(0));
                else if(right.isEmpty())
                    sorted.add(left.remove(0));
                else {
                    // <= keeps equal elements in their original order (stable sort)
                    if( left.get(0).compareTo(right.get(0)) <= 0 )
                        sorted.add(left.remove(0));
                    else
                        sorted.add(right.remove(0));
                }
            }
    
            return sorted;
        }
    
        public static void main(String[] args) {
            ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
            List<Integer> result = forkJoinPool.invoke(new MergeSort<Integer>(Arrays.asList(7,2,9,10,1)));
            System.out.println("result: " + result);
        }
    }
    

    Although less straightforward, the following variant eliminates the excessive copying of the ArrayList. The initial unsorted list is created only once, and the calls to subList return views that copy nothing; previously the list was copied every time the algorithm forked. When merging, instead of creating a new list and copying values into it each time, we reuse the left list and insert the values from the right list into it, which avoids the extra copy step. We use a LinkedList here because linking in a new node is cheap, whereas an insert into an ArrayList shifts every later element. We also eliminate the call to remove, which is likewise expensive on an ArrayList. This version still addresses the lists by index, a cost that is dealt with further below.

    import java.util.*;
    import java.util.concurrent.*;
    
    public class MergeSort<N extends Comparable<N>> extends RecursiveTask<List<N>> {
        private List<N> elements;
    
        public MergeSort(List<N> elements) {
            this.elements = elements;
        }
    
        @Override
        protected List<N> compute() {
            if(this.elements.size() <= 1)
                return new LinkedList<>(this.elements);
            else {
                final int pivot = this.elements.size() / 2;
                MergeSort<N> leftTask = new MergeSort<N>(this.elements.subList(0, pivot));
                MergeSort<N> rightTask = new MergeSort<N>(this.elements.subList(pivot, this.elements.size()));
    
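                // Fork one half and sort the other in the current thread,
                // so this worker stays busy instead of blocking in join().
                leftTask.fork();
                List<N> right = rightTask.compute();
                List<N> left = leftTask.join();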
    
                return merge(left, right);
            }
        }
    
        private List<N> merge(List<N> left, List<N> right) {
            int leftIndex = 0;
            int rightIndex = 0;
            while(leftIndex < left.size() || rightIndex < right.size()) {
                if(leftIndex >= left.size())
                    left.add(leftIndex++, right.get(rightIndex++));
                else if(rightIndex >= right.size())
                    return left;
                else {
                    // <= keeps equal elements in their original order (stable sort)
                    if( left.get(leftIndex).compareTo(right.get(rightIndex)) <= 0 )
                        leftIndex++;
                    else
                        left.add(leftIndex++, right.get(rightIndex++));
                }
            }
    
            return left;
        }
    
        public static void main(String[] args) {
            ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
            List<Integer> result = forkJoinPool.invoke(new MergeSort<Integer>(Arrays.asList(7,2,9,-7,777777,10,1)));
            System.out.println("result: " + result);
        }
    }
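
    The claim above that subList performs no copying can be checked in isolation. The following is a minimal sketch; SubListViewDemo and writeThroughViews are hypothetical names used only for illustration, and the method assumes a list with at least two elements.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical demo class; not part of the MergeSort code above.
class SubListViewDemo {
    // Overwrites the first element of each half through subList views
    // and returns the backing list so the effect is visible.
    // Assumes backing.size() >= 2.
    static List<Integer> writeThroughViews(List<Integer> backing) {
        int pivot = backing.size() / 2;
        List<Integer> left = backing.subList(0, pivot);               // view, no copy
        List<Integer> right = backing.subList(pivot, backing.size()); // view, no copy
        left.set(0, -1);   // changes backing.get(0)
        right.set(0, -2);  // changes backing.get(pivot)
        return backing;
    }

    public static void main(String[] args) {
        List<Integer> data = new ArrayList<>(Arrays.asList(7, 2, 9, 10));
        System.out.println(writeThroughViews(data)); // prints [-1, 2, -2, 10]
    }
}
```

    Writes made through a view land in the backing list, which is what lets the halves be handed to subtasks without copying. Only non-structural changes such as set are safe here: adding to or removing from the backing list directly invalidates the existing views.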
    

    We can improve the code one step further by using iterators instead of calling get during the merge. Calling get by index on a LinkedList takes linear time, because the list is traversed internally on every call, whereas next on an iterator takes constant time. The following code is modified to use iterators.

    import java.util.*;
    import java.util.concurrent.*;
    
    public class MergeSort<N extends Comparable<N>> extends RecursiveTask<List<N>> {
        private List<N> elements;
    
        public MergeSort(List<N> elements) {
            this.elements = elements;
        }
    
        @Override
        protected List<N> compute() {
            if(this.elements.size() <= 1)
                return new LinkedList<>(this.elements);
            else {
                final int pivot = this.elements.size() / 2;
                MergeSort<N> leftTask = new MergeSort<N>(this.elements.subList(0, pivot));
                MergeSort<N> rightTask = new MergeSort<N>(this.elements.subList(pivot, this.elements.size()));
    
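                // Fork one half and sort the other in the current thread,
                // so this worker stays busy instead of blocking in join().
                leftTask.fork();
                List<N> right = rightTask.compute();
                List<N> left = leftTask.join();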
    
                return merge(left, right);
            }
        }
    
        private List<N> merge(List<N> left, List<N> right) {
            ListIterator<N> leftIter = left.listIterator();
            ListIterator<N> rightIter = right.listIterator();
            while(leftIter.hasNext() || rightIter.hasNext()) {
                if(!leftIter.hasNext()) {
                    leftIter.add(rightIter.next());
                    rightIter.remove();
                }
                else if(!rightIter.hasNext())
                    return left;
                else {
                    N rightElement = rightIter.next();
                    // <= keeps equal elements in their original order (stable sort)
                    if( leftIter.next().compareTo(rightElement) <= 0 )
                        rightIter.previous();
                    else {
                        leftIter.previous();
                        leftIter.add(rightElement);
                    }
                }
            }
    
            return left;
        }
    
        public static void main(String[] args) {
            ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
            List<Integer> result = forkJoinPool.invoke(new MergeSort<Integer>(Arrays.asList(7,2,9,-7,777777,10,1)));
            System.out.println("result: " + result);
        }
    }
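
    Because the merge logic gets more intricate with each variant, a randomized comparison against Collections.sort is a cheap safeguard. The following is a minimal sketch; SortCheck and agreesWithLibrarySort are hypothetical names, and the sorter passed in main is only a stand-in so that the example runs on its own.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.function.UnaryOperator;

// Hypothetical test helper; the class and method names are illustrative.
class SortCheck {
    // Runs `sorter` on `rounds` random lists and compares each result
    // with Collections.sort applied to a copy of the same input.
    static boolean agreesWithLibrarySort(UnaryOperator<List<Integer>> sorter,
                                         int rounds, long seed) {
        Random random = new Random(seed);
        for (int round = 0; round < rounds; round++) {
            int size = random.nextInt(64);
            List<Integer> input = new ArrayList<>();
            for (int i = 0; i < size; i++)
                input.add(random.nextInt(100)); // small range forces duplicates

            List<Integer> expected = new ArrayList<>(input);
            Collections.sort(expected);

            // Hand the sorter its own mutable copy of the input.
            List<Integer> actual = sorter.apply(new ArrayList<>(input));
            if (!expected.equals(actual))
                return false;
        }
        return true;
    }

    public static void main(String[] args) {
        // Stand-in sorter (assumption) so the sketch is self-contained.
        UnaryOperator<List<Integer>> standIn = list -> {
            Collections.sort(list);
            return list;
        };
        System.out.println(agreesWithLibrarySort(standIn, 1000, 42L));
    }
}
```

    To exercise one of the variants, the stand-in would be replaced with something like list -> ForkJoinPool.commonPool().invoke(new MergeSort<>(list)). List.equals compares element by element, so it does not matter whether the variant returns an ArrayList or a LinkedList.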
    

    Finally, the most complex version of the code: this iteration operates entirely in place. Only the initial ArrayList is created, and no additional collections are ever allocated. The logic is correspondingly harder to follow (so I saved it for last). It uses the least memory of the four versions, though an in-place merge generally moves elements more often than a buffered merge, so it is a space optimization rather than necessarily the fastest variant.

    import java.util.*;
    import java.util.concurrent.*;
    
    public class MergeSort<N extends Comparable<N>> extends RecursiveTask<List<N>> {
        private List<N> elements;
    
        public MergeSort(List<N> elements) {
            this.elements = elements;
        }
    
        @Override
        protected List<N> compute() {
            if(this.elements.size() <= 1)
                return this.elements;
            else {
                final int pivot = this.elements.size() / 2;
                MergeSort<N> leftTask = new MergeSort<N>(this.elements.subList(0, pivot));
                MergeSort<N> rightTask = new MergeSort<N>(this.elements.subList(pivot, this.elements.size()));
    
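                // Fork one half and sort the other in the current thread,
                // so this worker stays busy instead of blocking in join().
                leftTask.fork();
                List<N> right = rightTask.compute();
                List<N> left = leftTask.join();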
    
                merge(left, right);
                return this.elements;
            }
        }
    
        // Merges two adjacent sorted views in place, without auxiliary storage.
        // Invariant: left[0..leftIndex) is in its final order and right is sorted.
        // Worst case is O(left.size() * right.size()) element moves.
        // (Swapping displaced elements into right.get(0) without re-sorting right
        // mis-merges inputs such as [3,4,10,11] and [1,2,5,6].)
        private void merge(List<N> left, List<N> right) {
            if(right.isEmpty())
                return;
            for(int leftIndex = 0; leftIndex < left.size(); leftIndex++) {
                if( left.get(leftIndex).compareTo(right.get(0)) > 0 ) {
                    // right.get(0) is the smallest remaining element, so it belongs here.
                    swap(left, leftIndex, right, 0);
                    // Sift the displaced element rightwards until right is sorted again.
                    for(int i = 0; i + 1 < right.size() && right.get(i).compareTo(right.get(i + 1)) > 0; i++)
                        swap(right, i, right, i + 1);
                }
            }
        }
    
        // Exchanges left[leftIndex] and right[rightIndex]; both arguments may be the same list.
        private void swap(List<N> left, int leftIndex, List<N> right, int rightIndex) {
            left.set(leftIndex, right.set(rightIndex, left.get(leftIndex)));
        }
    
        public static void main(String[] args) {
            ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
            List<Integer> result = forkJoinPool.invoke(new MergeSort<Integer>(new ArrayList<>(Arrays.asList(5,9,8,7,6,1,2,3,4))));
            System.out.println("result: " + result);
        }
    }
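
    For comparison, a common alternative gives up the zero-allocation goal in exchange for a linear-time merge, using one scratch array shared by all subtasks. This is a sketch under assumptions rather than the approach above: it sorts an int[] instead of a List<N>, extends RecursiveAction instead of RecursiveTask because nothing is returned, and the names ArrayMergeSort and THRESHOLD (with its untuned value) are illustrative.

```java
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

// Hypothetical alternative; class name and THRESHOLD value are assumptions.
class ArrayMergeSort extends RecursiveAction {
    // Below this size a range is sorted sequentially (untuned guess).
    private static final int THRESHOLD = 1 << 12;

    private final int[] data;
    private final int[] scratch; // shared buffer; each task touches only [from, to)
    private final int from;      // inclusive
    private final int to;        // exclusive

    ArrayMergeSort(int[] data, int[] scratch, int from, int to) {
        this.data = data;
        this.scratch = scratch;
        this.from = from;
        this.to = to;
    }

    // Convenience entry point: allocates the scratch array exactly once.
    static void sort(int[] data) {
        ForkJoinPool.commonPool()
                .invoke(new ArrayMergeSort(data, new int[data.length], 0, data.length));
    }

    @Override
    protected void compute() {
        if (to - from <= THRESHOLD) {
            Arrays.sort(data, from, to); // small ranges are not worth forking
            return;
        }
        int mid = (from + to) >>> 1;
        // invokeAll runs both halves and returns once both have completed.
        invokeAll(new ArrayMergeSort(data, scratch, from, mid),
                  new ArrayMergeSort(data, scratch, mid, to));
        merge(mid);
    }

    // Merges data[from, mid) and data[mid, to), both already sorted.
    private void merge(int mid) {
        // Only the left half is copied; the right half is read in place.
        System.arraycopy(data, from, scratch, from, mid - from);
        int left = from, right = mid, out = from;
        while (left < mid && right < to)
            data[out++] = scratch[left] <= data[right] ? scratch[left++] : data[right++];
        while (left < mid)
            data[out++] = scratch[left++];
        // Any remaining right-half elements are already in their final place.
    }

    public static void main(String[] args) {
        int[] values = {5, 9, 8, 7, 6, 1, 2, 3, 4};
        sort(values);
        System.out.println("result: " + Arrays.toString(values));
    }
}
```

    The scratch array costs O(n) extra space but keeps every merge linear, so the total work stays at O(n log n). Because sibling tasks work on disjoint index ranges of both arrays, no locking is needed.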