
Java: Wait in a loop until the tasks of a ThreadPoolExecutor are done before continuing


I'm working on parallelizing Dijkstra's algorithm. For each node, threads are created to examine all the edges of the current node. Creating new threads for every node adds too much overhead, so this parallel version turned out slower than the sequential version of the algorithm.

To solve this I switched to a thread pool, but I'm having trouble waiting until the tasks are done before moving on to the next iteration. We may only move on once all tasks for the current node are done: the results of all tasks are needed before I can search for the next closest node.

I tried calling executor.shutdown(), but with this approach the executor won't accept new tasks. How can I wait in the loop until every task is finished without having to create a new ThreadPoolExecutor on every iteration? Doing that would defeat the purpose of using a pool instead of regular threads, which is to reduce overhead.
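A minimal sketch of the behaviour described above, using only the standard java.util.concurrent API (the class ShutdownDemo and the method secondBatchRejected are hypothetical names): shutdown() followed by awaitTermination() does wait for the submitted tasks, but afterwards the same pool refuses new work, so it cannot be reused for the next node.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

class ShutdownDemo {

    /** Runs one batch of tasks, shuts the pool down, then tries to submit one more task. */
    static boolean secondBatchRejected() throws InterruptedException {
        ExecutorService executor = Executors.newFixedThreadPool(2);
        AtomicInteger done = new AtomicInteger();

        // First "iteration": four small tasks
        for (int t = 0; t < 4; t++) {
            executor.execute(done::incrementAndGet);
        }

        // shutdown() stops accepting new tasks; awaitTermination() blocks until queued ones finish
        executor.shutdown();
        executor.awaitTermination(10, TimeUnit.SECONDS);
        System.out.println("Tasks completed: " + done.get());

        // Second "iteration": the same pool now refuses the task
        try {
            executor.execute(done::incrementAndGet);
            return false;
        } catch (RejectedExecutionException e) {
            return true;
        }
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println("Second batch rejected: " + secondBatchRejected());
    }
}
```

The rejection comes from the pool's default rejection policy (ThreadPoolExecutor.AbortPolicy), which throws RejectedExecutionException for any task submitted after shutdown().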

One idea was a BlockingQueue to which the tasks (edges) are added, but with that solution I'm also stuck on waiting for the tasks to finish without calling shutdown().

    public void apply(int numberOfThreads) {
        ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(numberOfThreads);

        class DijkstraTask implements Runnable {

            private String name;

            public DijkstraTask(String name) {
                this.name = name;
            }

            public String getName() {
                return name;
            }

            @Override
            public void run() {
                calculateShortestDistances(numberOfThreads);
            }
        }

        // Visit every node, in order of stored distance
        for (int i = 0; i < this.nodes.length; i++) {

            //Add task for each node
            for (int t = 0; t < numberOfThreads; t++) {
                executor.execute(new DijkstraTask("Task " + t));
            }

            //Wait until finished?
            while (executor.getActiveCount() > 0) {
                System.out.println("Active count: " + executor.getActiveCount());
            }

            //Look through the results of the tasks and get the next node that is closest by
            currentNode = getNodeShortestDistanced();

            //Reset the threadCounter for next iteration
            this.setCount(0);
        }
    }

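The edges of the current node are divided evenly among the threads: with 8 edges and 2 threads, each thread handles 4 edges in parallel. If the number of edges is not divisible by the number of threads, the last thread also processes the remaining edges.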
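The split can be sketched as a small pure function (the class EdgePartition and the method range are hypothetical names, not part of the original code). It returns the half-open index range each thread handles, and the last thread also takes the leftover size % numberOfThreads edges, which start at edgesPerThread * numberOfThreads. The thread index is passed in as an argument: in the code below the index comes from getCount() followed by setCount(count + 1), and that read-then-write is not atomic, so unless those calls are synchronized together somewhere not shown, two tasks can end up with the same index. Passing the loop variable t into each DijkstraTask would avoid the shared counter.

```java
import java.util.Arrays;

class EdgePartition {

    /**
     * Returns {start, end} (end exclusive) of the edge indices handled by thread threadIndex.
     * Every thread gets totalEdges / numberOfThreads edges; the last thread also takes
     * the remaining totalEdges % numberOfThreads edges.
     */
    static int[] range(int totalEdges, int numberOfThreads, int threadIndex) {
        int edgesPerThread = totalEdges / numberOfThreads;
        int start = edgesPerThread * threadIndex;
        int end = start + edgesPerThread;
        if (threadIndex == numberOfThreads - 1) {
            end = totalEdges; // the last thread picks up the remainder
        }
        return new int[] {start, end};
    }

    public static void main(String[] args) {
        // 8 edges, 2 threads: each thread handles 4 edges
        for (int t = 0; t < 2; t++) {
            System.out.println("8 edges, thread " + t + ": " + Arrays.toString(range(8, 2, t)));
        }
        // 9 edges, 2 threads: the last thread also handles the leftover edge
        for (int t = 0; t < 2; t++) {
            System.out.println("9 edges, thread " + t + ": " + Arrays.toString(range(9, 2, t)));
        }
    }
}
```

For 8 edges the ranges are [0, 4] and [4, 8]; for 9 edges the second thread's range grows to [4, 9].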

    public void calculateShortestDistances(int numberOfThreads) {

        int threadCounter = this.getCount();
        this.setCount(count + 1);

        // Loop round the edges that are joined to the current node
        currentNodeEdges = this.nodes[currentNode].getEdges();

        int edgesPerThread = currentNodeEdges.size() / numberOfThreads;
        int modulo = currentNodeEdges.size() % numberOfThreads;
        this.nodes[0].setDistanceFromSource(0);
        //Process the edges per thread
        for (int joinedEdge = (edgesPerThread * threadCounter); joinedEdge < (edgesPerThread * (threadCounter + 1)); joinedEdge++) {

            System.out.println("Start: " + (edgesPerThread * threadCounter) + ". End: " + (edgesPerThread * (threadCounter + 1)) + ". JoinedEdge: " + joinedEdge + ". Total: " + currentNodeEdges.size());
            // Determine the joined edge neighbour of the current node
            int neighbourIndex = currentNodeEdges.get(joinedEdge).getNeighbourIndex(currentNode);

            // Only interested in an unvisited neighbour
            if (!this.nodes[neighbourIndex].isVisited()) {
                // Calculate the tentative distance for the neighbour
                int tentative = this.nodes[currentNode].getDistanceFromSource() + currentNodeEdges.get(joinedEdge).getLength();
                // Overwrite if the tentative distance is less than what's currently stored
                if (tentative < nodes[neighbourIndex].getDistanceFromSource()) {
                    nodes[neighbourIndex].setDistanceFromSource(tentative);
                }
            }
        }

        //if we have a modulo above 0, the last thread will process the remaining edges
        if (modulo > 0 && numberOfThreads == (threadCounter + 1)) {
            // The leftover edges start right after the last regular chunk
            for (int joinedEdge = (edgesPerThread * (threadCounter + 1)); joinedEdge < (edgesPerThread * (threadCounter + 1) + modulo); joinedEdge++) {
                // Determine the joined edge neighbour of the current node
                int neighbourIndex = currentNodeEdges.get(joinedEdge).getNeighbourIndex(currentNode);

                // Only interested in an unvisited neighbour
                if (!this.nodes[neighbourIndex].isVisited()) {
                    // Calculate the tentative distance for the neighbour
                    int tentative = this.nodes[currentNode].getDistanceFromSource() + currentNodeEdges.get(joinedEdge).getLength();
                    // Overwrite if the tentative distance is less than what's currently stored
                    if (tentative < nodes[neighbourIndex].getDistanceFromSource()) {
                        nodes[neighbourIndex].setDistanceFromSource(tentative);
                    }
                }
            }
        }
        // All neighbours are checked so this node is now visited
        nodes[currentNode].setVisited(true);
    }

Thanks for helping me!


Solution

  • You should look into CyclicBarrier or CountDownLatch. Both let threads wait until other threads have signaled that they're done. The difference between them is that a CyclicBarrier is reusable, i.e. it can be used multiple times, while a CountDownLatch is one-shot: its count cannot be reset.

    Paraphrasing from the Javadocs:

    A CountDownLatch is a synchronization aid that allows one or more threads to wait until a set of operations being performed in other threads completes.

    A CyclicBarrier is a synchronization aid that allows a set of threads to all wait for each other to reach a common barrier point. CyclicBarriers are useful in programs involving a fixed sized party of threads that must occasionally wait for each other. The barrier is called cyclic because it can be re-used after the waiting threads are released.

    https://docs.oracle.com/en/java/javase/11/docs/api/java.base/java/util/concurrent/CyclicBarrier.html

    https://docs.oracle.com/en/java/javase/11/docs/api/java.base/java/util/concurrent/CountDownLatch.html
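One way to apply CountDownLatch to the loop in the question, sketched under assumptions: keep one fixed pool for the whole run and create a fresh latch for every node. Creating a latch is cheap compared with creating threads, so its one-shot nature is not a problem here. The class LatchPerIteration and the counter standing in for the edge relaxation are hypothetical; the real tasks would run calculateShortestDistances.

```java
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicIntegerArray;

class LatchPerIteration {

    /**
     * Runs `iterations` rounds of `numberOfThreads` tasks on ONE reused pool.
     * Each round waits on its own CountDownLatch before the next round starts.
     * Returns, per round, how many tasks had finished when await() returned.
     */
    static int[] run(int numberOfThreads, int iterations) throws InterruptedException {
        ExecutorService executor = Executors.newFixedThreadPool(numberOfThreads);
        AtomicIntegerArray finished = new AtomicIntegerArray(iterations);
        int[] observed = new int[iterations];
        try {
            for (int i = 0; i < iterations; i++) {
                // A new latch per round is cheap; the pool's threads are reused
                CountDownLatch latch = new CountDownLatch(numberOfThreads);
                final int round = i;
                for (int t = 0; t < numberOfThreads; t++) {
                    executor.execute(() -> {
                        try {
                            // Stand-in for relaxing this task's share of the edges
                            finished.incrementAndGet(round);
                        } finally {
                            latch.countDown(); // count down even if the work throws
                        }
                    });
                }
                // Blocks until every task of this round has called countDown()
                latch.await();
                observed[i] = finished.get(round);
                // The real code would now call getNodeShortestDistanced()
            }
        } finally {
            executor.shutdown(); // only once, after all rounds
        }
        return observed;
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(Arrays.toString(run(4, 3))); // [4, 4, 4]
    }
}
```

Calling countDown() in a finally block matters: if a task threw before counting down, await() would block forever.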
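The reusable CyclicBarrier approach can be sketched the other way around: numberOfThreads long-lived workers each loop over all iterations and wait at one shared barrier after every round. The optional barrier action runs once per round, in the last thread to arrive and before any thread is released, which is a natural place for a step like getNodeShortestDistanced(). The class BarrierPerIteration and the counter standing in for real work are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;

class BarrierPerIteration {

    /**
     * numberOfThreads long-lived workers each run `iterations` rounds and meet at
     * one shared CyclicBarrier after every round. The barrier action records how
     * much work had been done each time the barrier tripped.
     */
    static List<Integer> run(int numberOfThreads, int iterations)
            throws InterruptedException, ExecutionException {
        AtomicInteger workDone = new AtomicInteger();
        List<Integer> doneAtEachTrip = new CopyOnWriteArrayList<>();

        // Runs once per round, in the last arriving thread, before any worker is released.
        // The real code could pick the next closest node here.
        CyclicBarrier barrier = new CyclicBarrier(numberOfThreads,
                () -> doneAtEachTrip.add(workDone.get()));

        // The pool needs at least as many threads as the barrier has parties
        ExecutorService executor = Executors.newFixedThreadPool(numberOfThreads);
        try {
            List<Future<?>> workers = new ArrayList<>();
            for (int t = 0; t < numberOfThreads; t++) {
                workers.add(executor.submit(() -> {
                    for (int i = 0; i < iterations; i++) {
                        workDone.incrementAndGet(); // stand-in for this worker's edges
                        try {
                            barrier.await(); // the same barrier is reused every round
                        } catch (InterruptedException e) {
                            Thread.currentThread().interrupt();
                            return;
                        } catch (BrokenBarrierException e) {
                            throw new IllegalStateException(e);
                        }
                    }
                }));
            }
            for (Future<?> worker : workers) {
                worker.get(); // rethrows (wrapped) any failure in a worker
            }
        } finally {
            executor.shutdown();
        }
        return doneAtEachTrip;
    }

    public static void main(String[] args) throws Exception {
        System.out.println(run(4, 3)); // [4, 8, 12]
    }
}
```

The pool size is the design constraint here: if the pool had fewer threads than the barrier has parties, the waiting workers would occupy every thread, the remaining workers would never be scheduled, and the barrier would never trip.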