Tags: java, multithreading, concurrency, executorservice, forkjoinpool

How do I run something parallel in Java?


I am trying to print all possible combinations within a range. For example, if my lowerBound is 3 and my max is 5, I want the following combinations: (5,4), (5,3), (4,3). I've implemented this with the helper() function found below.

Of course, if my max is very big, this is a lot of combinations and it will take a long time. That's why I'm trying to use a ForkJoinPool, so that the tasks run in parallel. For this I create a new ForkJoinPool. Then I loop over all possible values of r (where r is the number of elements in the combination; in the above example r = 2). For every value of r I create a new HelperCalculator, which extends RecursiveTask<Void>. In there I recursively call the helper() function. Every time I call it, I create a new HelperCalculator and call .fork() on it.

The problem is as follows: it is not correctly generating all possible combinations. In fact, it generates no combinations at all. I've tried adding calculator.join() after calculator.fork(), but that just runs forever until I get an OutOfMemoryError.

Obviously there is something I'm misunderstanding about the ForkJoinPool, but after trying for days I can't see what.

My main function:

    ForkJoinPool pool = (ForkJoinPool) Executors.newWorkStealingPool();
    for (int r = 1; r < 25; r++) {
        int lowerBound = 7;
        int[] data = new int[r];
        int max = 25;
        calculator = new HelperCalculator(data, 0, max, 0, s, n, lowerBound);
        pool.execute(calculator);
        calculator.join();
    }
    pool.shutdown();

The HelperCalculator class:

    protected Void compute() {
        helper(data, end, start, index, s, lowerBound);
        return null;
    }

    // Generate all possible combinations
    public void helper(int[] data, int end, int start, int index, int s, int lowerBound) {
        // If the array is filled, print it
        if (index == data.length) {
            System.out.println(Arrays.toString(data));
        } else if (start >= end) {
            data[index] = start;
            if (data[0] >= lowerBound) {
                HelperCalculator calculator = new HelperCalculator(data, end, start - 1, index + 1, s, n, lowerBound);
                calculator.fork();
                calculators.add(calculator);
                HelperCalculator calculator2 = new HelperCalculator(data, end, start - 1, index, s, n, lowerBound);
                calculator2.fork();
                calculators.add(calculator2);
            }
        }
    }

How do I make every HelperCalculator run parallel, so that there are 23 running at the same time using a ForkJoinPool? Or should I perhaps use a different solution?

I've tried calling join() and isDone() on the calculators list, but then it doesn't wait for it to finish properly and the program just exits.

Since the algorithm itself caused some confusion, here it is in sequential form:

    public static void main(String[] args) {
        for (int r = 3; r > 0; r--) {
            int[] data = new int[r];
            helper(data, 0, 2, 0);
        }
    }

    public static void helper(int[] data, int end, int start, int index) {
        if (index == data.length) {
            System.out.println(Arrays.toString(data));
        } else if (start >= end) {
            data[index] = start;
            helper(data, end, start - 1, index + 1);
            helper(data, end, start - 1, index);
        }
    }

The output of this is:

[2, 1, 0]
[2, 1]
[2, 0]
[1, 0]
[2]
[1]
[0]

Solution

  • Some of the tasks you are forking attempt to use the same array for evaluating different combinations. You can solve the issue by creating a distinct array for each task, or by limiting the parallelism to those tasks which already have an array of their own, i.e. those with a different length.
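    A minimal sketch of the first fix, assuming the recursive structure from the question (the class and method names here are illustrative, not the original code): each forked subtask receives its own clone of the array, so sibling tasks can no longer overwrite each other's slots, and invokeAll() ensures the parent waits for both.

```java
import java.util.Arrays;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

public class DistinctArrayDemo {
    // collected instead of printed, so the result can be inspected
    static final Set<String> results = ConcurrentHashMap.newKeySet();

    static class Helper extends RecursiveAction {
        final int[] data;
        final int end, start, index;

        Helper(int[] data, int end, int start, int index) {
            this.data = data; this.end = end; this.start = start; this.index = index;
        }

        @Override
        protected void compute() {
            if (index == data.length) {
                results.add(Arrays.toString(data));
            } else if (start >= end) {
                data[index] = start;
                // each subtask gets its own clone, so siblings cannot
                // overwrite each other's slots in the shared array
                invokeAll(new Helper(data.clone(), end, start - 1, index + 1),
                          new Helper(data.clone(), end, start - 1, index));
            }
        }
    }

    public static void main(String[] args) {
        ForkJoinPool pool = new ForkJoinPool();
        for (int r = 3; r > 0; r--) {
            pool.invoke(new Helper(new int[r], 0, 2, 0));
        }
        System.out.println(results.size()); // 7, matching the sequential example
    }
}
```

    Running this against the small example from the question (start = 2, r = 3…1) yields the same seven combinations as the sequential version.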

    But there's another possibility: don't use arrays at all. You can store combinations in int values, as each int value is a combination of bits. Not only does this save a lot of memory, you can also easily iterate over all possible combinations by just incrementing the value, as iterating over all int numbers also iterates over all possible bit combinations¹. The only thing we need to implement is generating the right string for a particular int value, interpreting the bits as numbers according to their position.

    For a first attempt, we can take the easy way and use already existing classes:

    import java.util.BitSet;
    import java.util.stream.Collectors;

    public static void main(String[] args) {
        long t0 = System.nanoTime();
        combinations(10, 25);
        long t1 = System.nanoTime();
        System.out.println((t1 - t0)/1_000_000+" ms");
        System.out.flush();
    }
    static void combinations(int start, int end) {
        for(int i = 1, stop = (1 << (end - start)) - 1; i <= stop; i++) {
            System.out.println(
                BitSet.valueOf(new long[]{i}).stream()
                      .mapToObj(b -> String.valueOf(b + start))
                      .collect(Collectors.joining(", ", "[", "]"))
            );
        }
    }
    

    The method uses an exclusive end, so for your example, you have to call it like combinations(0, 3) and it will print

    [0]
    [1]
    [0, 1]
    [2]
    [0, 2]
    [1, 2]
    [0, 1, 2]
    3 ms
    

    Of course, timing may vary.

    For the combinations(10, 25) example above, it prints all combinations, followed by 3477 ms on my machine. This sounds like an opportunity to optimize, but we should first think about which operations impose which costs.

    Iterating over the combinations has been reduced to a trivial operation here. Creating the string is an order of magnitude more expensive. But this is still nothing compared to the actual printing which includes a data transfer to the operating system and, depending on the system, the actual rendering may add to our time. Since this is done while holding a lock within PrintStream, all threads attempting to print at the same time would be blocked, making it a nonparallelizable operation.

    Let's identify the fractions of the cost by creating a new PrintStream, disabling the auto-flush on line breaks, and using an insanely large buffer, capable of holding the entire output:

    import java.io.BufferedOutputStream;
    import java.io.FileDescriptor;
    import java.io.FileOutputStream;
    import java.io.PrintStream;
    import java.util.BitSet;
    import java.util.stream.Collectors;

    public static void main(String[] args) {
        System.setOut(new PrintStream(
            new BufferedOutputStream(new FileOutputStream(FileDescriptor.out),1<<20),false));
        long t0 = System.nanoTime();
        combinations(10, 25);
        long t1 = System.nanoTime();
        System.out.flush();
        long t2 = System.nanoTime();
        System.out.println((t1 - t0)/1_000_000+" ms");
        System.out.println((t2 - t0)/1_000_000+" ms");
        System.out.flush();
    }
    static void combinations(int start, int end) {
        for(int i = 1, stop = (1 << (end - start)) - 1; i <= stop; i++) {
            System.out.println(
                BitSet.valueOf(new long[]{i}).stream()
                      .mapToObj(b -> String.valueOf(b + start))
                      .collect(Collectors.joining(", ", "[", "]"))
            );
        }
    }
    

    On my machine, it prints something in the order of

    93 ms
    3340 ms
    

    Showing that the code spent more than three seconds on the nonparallelizable printing and only about 100 milliseconds on the calculation. For completeness, the following code goes a level down for the String generation:

    static void combinations(int start, int end) {
        for(int i = 1, stop = (1 << (end - start)) - 1; i <= stop; i++) {
            System.out.println(bits(i, start));
        }
    }
    static String bits(int bits, int offset) {
        StringBuilder sb = new StringBuilder().append('[');
        for(;;) {
            int bit = Integer.lowestOneBit(bits), num = Integer.numberOfTrailingZeros(bit);
            sb.append(num + offset);
            bits -= bit;
            if(bits == 0) break;
            sb.append(", ");
        }
        return sb.append(']').toString();
    }
    

    which halves the calculation time on my machine, while having no noticeable impact on the total time, which shouldn't come as a surprise by now.
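    To make the bit-to-combination mapping concrete, here is a tiny standalone check (the class name is made up, the bits() method is the one from above): the value 0b101 has bits 0 and 2 set, so with an offset of 10 it denotes the combination of the numbers 10 and 12.

```java
public class BitsMappingDemo {
    // same logic as the bits() method above
    static String bits(int bits, int offset) {
        StringBuilder sb = new StringBuilder().append('[');
        for (;;) {
            int bit = Integer.lowestOneBit(bits), num = Integer.numberOfTrailingZeros(bit);
            sb.append(num + offset);
            bits -= bit;
            if (bits == 0) break;
            sb.append(", ");
        }
        return sb.append(']').toString();
    }

    public static void main(String[] args) {
        // binary 101 → bits 0 and 2 are set → numbers offset+0 and offset+2
        System.out.println(bits(0b101, 10)); // [10, 12]
    }
}
```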


    But for education purposes, ignoring the lack of potential acceleration, let’s discuss how we would parallelize this operation.

    The sequential code did already bring the task into a form which boils down to an iteration from a start value to an end value. Now, we rewrite this code to a ForkJoinTask (or suitable subclass) which represents an iteration with a start and end value. Then, we add the ability to split this operation into two, by splitting the range in the middle, so we get two tasks iterating over each half of the range. This can be repeated until we decide to have enough potentially parallel jobs and perform the current iteration locally. After the local processing we have to wait for the completion of any task we split off, to ensure that the completion of the root task implies the completion of all subtasks.

    import java.io.BufferedOutputStream;
    import java.io.FileDescriptor;
    import java.io.FileOutputStream;
    import java.io.PrintStream;
    import java.util.ArrayDeque;
    import java.util.concurrent.Executors;
    import java.util.concurrent.ForkJoinPool;
    import java.util.concurrent.RecursiveAction;

    public class Combinations extends RecursiveAction {
        public static void main(String[] args) {
            System.setOut(new PrintStream(new BufferedOutputStream(
                new FileOutputStream(FileDescriptor.out),1<<20),false));
            ForkJoinPool pool = (ForkJoinPool) Executors.newWorkStealingPool();
            long t0 = System.nanoTime();
            Combinations job = Combinations.get(10, 25);
            pool.execute(job);
            job.join();
            long t1 = System.nanoTime();
            System.out.flush();
            long t2 = System.nanoTime();
            System.out.println((t1 - t0)/1_000_000+" ms");
            System.out.println((t2 - t0)/1_000_000+" ms");
            System.out.flush();
        }
    
        public static Combinations get(int min, int max) {
            return new Combinations(min, 1, (1 << (max - min)) - 1);
        }
    
        final int offset, from;
        int to;
    
        private Combinations(int offset, int from, int to) {
            this.offset = offset;
            this.from = from;
            this.to = to;
        }
    
        @Override
        protected void compute() {
            ArrayDeque<Combinations> spawned = new ArrayDeque<>();
            while(getSurplusQueuedTaskCount() < 2) {
                int middle = (from + to) >>> 1;
                if(middle == from) break;
                Combinations forked = new Combinations(offset, middle, to);
                forked.fork();
                spawned.addLast(forked);
                to = middle - 1;
            }
            performLocal();
            for(;;) {
                Combinations forked = spawned.pollLast();
                if(forked == null) break;
                if(forked.tryUnfork()) forked.performLocal(); else forked.join();
            }
        }
    
        private void performLocal() {
            for(int i = from, stop = to; i <= stop; i++) {
                System.out.println(bits(i, offset));
            }
        }
    
        static String bits(int bits, int offset) {
            StringBuilder sb = new StringBuilder().append('[');
            for(;;) {
                int bit=Integer.lowestOneBit(bits), num=Integer.numberOfTrailingZeros(bit);
                sb.append(num + offset);
                bits -= bit;
                if(bits == 0) break;
                sb.append(", ");
            }
            return sb.append(']').toString();
        }
    }
    

    The getSurplusQueuedTaskCount() method provides a hint about the saturation of the worker threads, in other words, whether forking more jobs might be beneficial. The returned number is compared with a threshold that is typically a small number; the more heterogeneous the jobs, and hence the expected workload, the higher the threshold should be, to allow more work-stealing when some jobs complete earlier than others. In our case, the workload is expected to be very well balanced.

    There are two ways of splitting. Examples often create two or more forked subtasks, followed by joining them. This may lead to a large number of tasks just waiting for others. The alternative is to fork one subtask and alter the current task to represent the other half. Here, the forked task represents the [middle, to] range whereas the current task is modified to represent the [from, middle - 1] range.

    After forking enough tasks, the remaining range is processed locally in the current thread. Then, the task will wait for all forked subtasks, with one optimization: it will try to unfork the subtasks, to process them locally if no other worker thread has stolen them yet.

    This works smoothly, but unfortunately, as expected, it does not accelerate the operation, as the most expensive part is the printing.


    ¹ Using an int to represent all combinations limits the supported range length to 31, but keep in mind that such a range length implies 2³¹ - 1 combinations, which is quite a lot to iterate over. If that still feels like a limitation, you may change the code to use long instead. The then-supported range length of 63, in other words 2⁶³ - 1 combinations, is enough to keep the computer busy until the end of the universe.
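    A sketch of that long-based variant of the bits() method (the class name is illustrative); Long.lowestOneBit and Long.numberOfTrailingZeros are drop-in replacements for their Integer counterparts:

```java
public class LongCombinations {
    // same string generation as before, widened from int to long
    static String bits(long bits, int offset) {
        StringBuilder sb = new StringBuilder().append('[');
        for (;;) {
            long bit = Long.lowestOneBit(bits);
            int num = Long.numberOfTrailingZeros(bit);
            sb.append(num + offset);
            bits -= bit;
            if (bits == 0) break;
            sb.append(", ");
        }
        return sb.append(']').toString();
    }

    public static void main(String[] args) {
        // bit 40 set: a position an int-based value cannot represent
        System.out.println(bits(1L << 40, 0)); // [40]
    }
}
```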