
More efficient way to prune a directed acyclic graph in jgrapht?


I am looking for a more efficient way to prune a directed acyclic graph (a DAG) constructed in jgrapht.

The DAG represents the relationships between a set of network conversations in time. The parents of a conversation are any conversations which completed before the child conversation started. Constructing the DAG is relatively straightforward, but there are a lot of unnecessary relationships. For efficiency, I want to prune the DAG so each child has a direct relationship to the minimal number of parents (or conversely, so each parent has the minimal number of immediate children).
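To see why the edge count explodes, here is a minimal sketch of the construction rule described above ("parents are conversations that completed before the child started"). It assumes a hypothetical `Conversation` class with numeric start/end times and uses plain Java collections rather than jgrapht.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical conversation record: an id plus start/end timestamps (e.g. epoch millis). */
final class Conversation {
    final String id;
    final long start;
    final long end;

    Conversation(String id, long start, long end) {
        this.id = id;
        this.start = start;
        this.end = end;
    }
}

/** Hypothetical helper that applies the "completed before started" rule to every pair. */
final class NaiveDagBuilder {
    /** Returns {parentId, childId} pairs; assumes start <= end, so no cycles can arise. */
    static List<String[]> buildEdges(List<Conversation> conversations) {
        List<String[]> edges = new ArrayList<>();
        for (Conversation parent : conversations) {
            for (Conversation child : conversations) {
                // Strict comparison: a parent must finish before the child begins.
                if (parent.end < child.start) {
                    edges.add(new String[] {parent.id, child.id});
                }
            }
        }
        return edges;
    }
}
```

For n conversations that run strictly one after another, this rule produces n(n-1)/2 edges, while the transitive reduction keeps only the n-1 edges of the chain.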

The prune implementation I am using now (shown below) is based on code found in streme. It works for all of my manually constructed unit test scenarios, but on real data it is often fairly slow. I ran across a scenario today with 215 vertices and over 22,000 edges. Pruning that DAG took almost 8 minutes of wall-clock time on server-class hardware: tolerable for my immediate use case, but too slow to scale to larger scenarios.

I believe my problem is similar to the ones described in "What algorithm can I apply to this DAG?" and "Algorithm for Finding Redundant Edges in a Graph or Tree". That is, I need to find the transitive reduction (the minimal representation) of my DAG. jgrapht does not appear to contain a direct implementation of transitive reduction for a DAG, only transitive closure.
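Concretely, in a DAG an edge u -> w is redundant exactly when w can also be reached from some other child of u. The following is a naive sketch that applies only that rule, over a plain adjacency map with hypothetical class and method names (no jgrapht). It runs one depth-first search per (vertex, child) pair, so it shares the scaling problem and is meant to illustrate the definition, not as a tuned implementation.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

/** Hypothetical helper: transitive reduction of a DAG stored as vertex -> children. */
final class NaiveTransitiveReduction {
    /** Returns a new adjacency map that keeps only the non-redundant edges. */
    static Map<String, Set<String>> reduce(Map<String, Set<String>> dag) {
        Map<String, Set<String>> reduced = new LinkedHashMap<>();
        for (Map.Entry<String, Set<String>> entry : dag.entrySet()) {
            Set<String> children = entry.getValue();
            // Everything reachable in one or more steps from some child of this vertex.
            Set<String> indirect = new HashSet<>();
            for (String child : children) {
                indirect.addAll(properDescendants(dag, child));
            }
            // A child that is also reachable indirectly does not need a direct edge.
            Set<String> kept = new LinkedHashSet<>(children);
            kept.removeAll(indirect);
            reduced.put(entry.getKey(), kept);
        }
        return reduced;
    }

    /** Iterative DFS collecting every vertex reachable from start, excluding start itself. */
    private static Set<String> properDescendants(Map<String, Set<String>> dag, String start) {
        Set<String> seen = new HashSet<>();
        Deque<String> stack = new ArrayDeque<>();
        stack.push(start);
        while (!stack.isEmpty()) {
            String v = stack.pop();
            for (String next : dag.getOrDefault(v, Set.of())) {
                if (seen.add(next)) {
                    stack.push(next);
                }
            }
        }
        return seen;
    }
}
```

For a -> b, a -> c, b -> c, the edge a -> c is dropped because c is reachable through b.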

I am looking for suggestions about how to improve the efficiency of the implementation below, or perhaps a pointer to an existing implementation of transitive reduction for jgrapht that I could use instead.

Note: Alternatively, if there is a different graph library for Java that includes a native implementation of transitive reduction, I could switch to it. My use of jgrapht is confined to a single 200-line class, so swapping it out should not be difficult as long as the interface is similar. To preserve the class interface (persisted to a database), I need a DAG implementation that provides a way to get the parents and children of a given node, similar to jgrapht's Graphs.predecessorListOf() and Graphs.successorListOf().
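If no library fits, the parent/child requirement on its own is small. Below is a sketch of a hypothetical minimal class (not part of jgrapht or any other library) that stores edges in both directions so that parent and child lookups do not require scanning the edge set. It performs no cycle detection, so it is not a drop-in replacement for jgrapht's DirectedAcyclicGraph.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/** Hypothetical minimal DAG keeping both directions; no cycle check is performed. */
final class SimpleDag<V> {
    private final Map<V, Set<V>> children = new LinkedHashMap<>();
    private final Map<V, Set<V>> parents = new LinkedHashMap<>();

    void addVertex(V v) {
        children.computeIfAbsent(v, k -> new LinkedHashSet<>());
        parents.computeIfAbsent(v, k -> new LinkedHashSet<>());
    }

    /** Adds parent -> child, creating either vertex if needed. */
    void addEdge(V parent, V child) {
        addVertex(parent);
        addVertex(child);
        children.get(parent).add(child);
        parents.get(child).add(parent);
    }

    /** Removes parent -> child from both indexes; returns false if the edge was absent. */
    boolean removeEdge(V parent, V child) {
        Set<V> out = children.get(parent);
        if (out == null || !out.remove(child)) {
            return false;
        }
        parents.get(child).remove(parent);
        return true;
    }

    /** Rough analogue of Graphs.predecessorListOf(); returns a copy in insertion order. */
    List<V> parentsOf(V v) {
        return new ArrayList<>(parents.getOrDefault(v, Set.of()));
    }

    /** Rough analogue of Graphs.successorListOf(); returns a copy in insertion order. */
    List<V> childrenOf(V v) {
        return new ArrayList<>(children.getOrDefault(v, Set.of()));
    }
}
```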

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.jgrapht.DirectedGraph;
import org.jgrapht.Graphs;
import org.jgrapht.alg.DijkstraShortestPath;

public static <V, E> void prune(DirectedGraph<V, E> dag) {
   Deque<V> todo = new ArrayDeque<V>(dag.vertexSet());
   Set<V> seen = new HashSet<V>();
   while (!todo.isEmpty()) {
       V v = todo.pop();
       if (seen.contains(v)) {
           continue;
       }
       seen.add(v);
       List<V> targets = Graphs.successorListOf(dag, v);
       for (int i = 0; i < targets.size(); i++) {
           for (int j = i; j < targets.size(); j++) {
               V vi = targets.get(i);
               V vj = targets.get(j);
               List<E> path = DijkstraShortestPath.findPathBetween(dag, vi, vj);
               if (path != null && !path.isEmpty()) {
                   E edge = dag.getEdge(v, vj);
                   dag.removeEdge(edge);
               }
           }
       }
   }
}

Solution

  • Optimized Implementation

    Below is the optimized implementation with the cache, as mentioned above in my first comment.

    import java.util.ArrayDeque;
    import java.util.Deque;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    
    import org.jgrapht.Graphs;
    import org.jgrapht.experimental.dag.DirectedAcyclicGraph;
    import org.jgrapht.traverse.BreadthFirstIterator;
    
    /**
     * A class to compute transitive reduction for a jgrapht DAG.
     * The basis for this implementation is streme (URL below), but I have made a variety of changes.
     * It is assumed that each vertex of type V has a toString() method which uniquely identifies it.
     * @see <a href="https://code.google.com/p/streme/source/browse/streme/src/streme/lang/ast/analysis/ipda/DependencyGraphParallelizer.java">streme</a>
     * @see <a href="http://en.wikipedia.org/wiki/Transitive_reduction">Transitive Reduction</a>
     * @see <a href="http://en.wikipedia.org/wiki/Dijkstra's_algorithm">Dijkstra's Algorithm</a>
     * @see <a href="http://en.wikipedia.org/wiki/Breadth-first_search">Breadth-First Search</a>
     */
    public class TransitiveReduction {
    
        /**
         * Compute transitive reduction for a DAG.
         * Each vertex is assumed to have a toString() method which uniquely identifies it.
         * @param graph   Graph to compute transitive reduction for
         */
        public static <V, E> void prune(DirectedAcyclicGraph<V, E> graph) {
            ConnectionCache<V, E> cache = new ConnectionCache<V, E>(graph);
            Deque<V> deque = new ArrayDeque<V>(graph.vertexSet());
            while (!deque.isEmpty()) {
                V vertex = deque.pop();
                prune(graph, vertex, cache);
            }
        }
    
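        /**
         * Prune a particular vertex in a DAG, using the passed-in cache.
         * An edge vertex -> child is redundant if another child of vertex can reach it,
         * so each pair of children is checked in both directions.
         */
        private static <V, E> void prune(DirectedAcyclicGraph<V, E> graph, V vertex, ConnectionCache<V, E> cache) {
            List<V> targets = Graphs.successorListOf(graph, vertex);
            for (int i = 0; i < targets.size(); i++) {
                for (int j = i + 1; j < targets.size(); j++) {
                    V child1 = targets.get(i);
                    V child2 = targets.get(j);
                    if (cache.isConnected(child1, child2)) {
                        removeEdgeIfPresent(graph, vertex, child2);
                    } else if (cache.isConnected(child2, child1)) {
                        removeEdgeIfPresent(graph, vertex, child1);
                    }
                }
            }
        }

        /** Remove vertex -> child unless it was already removed earlier in the loop. */
        private static <V, E> void removeEdgeIfPresent(DirectedAcyclicGraph<V, E> graph, V vertex, V child) {
            E edge = graph.getEdge(vertex, child);
            if (edge != null) {
                graph.removeEdge(edge);
            }
        }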
    
        /** A cache that stores previously-computed connections between vertices. */
        private static class ConnectionCache<V, E> {
            private DirectedAcyclicGraph<V, E> graph;
            private Map<String, Boolean> map;
    
            public ConnectionCache(DirectedAcyclicGraph<V, E> graph) {
                this.graph = graph;
                this.map = new HashMap<String, Boolean>(graph.edgeSet().size());
            }
    
            public boolean isConnected(V startVertex, V endVertex) {
                String key = startVertex.toString() + "-" + endVertex.toString();
    
                if (!this.map.containsKey(key)) {
                    boolean connected = isConnected(this.graph, startVertex, endVertex);
                    this.map.put(key, connected);
                }
    
                return this.map.get(key);
            }
    
            private static <V, E> boolean isConnected(DirectedAcyclicGraph<V, E> graph, V startVertex, V endVertex) {
                BreadthFirstIterator<V, E> iter = new BreadthFirstIterator<V, E>(graph, startVertex);
    
                while (iter.hasNext()) {
                    V vertex = iter.next();
                    if (vertex.equals(endVertex)) {
                        return true;
                    }
                }
    
                return false;
            }
        }
    
    }
    

    Improvements

    Among other minor changes, I improved the streme implementation by adding a cache, so we do not need to recompute whether a path exists between two vertices that have already been checked. I also changed it to use a BreadthFirstIterator to check for connections between nodes, rather than relying on Dijkstra's Algorithm. Dijkstra's Algorithm computes the shortest path, but all we care about here is whether any path exists. Because the breadth-first search stops as soon as it reaches the target vertex, this implementation is quite a bit more efficient than the original.
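The cache above keys on toString(), which relies on every vertex having a unique string form. Here is a sketch of the same idea (memoized reachability with an early-exit breadth-first search) that instead keys on List.of(from, to), whose equals/hashCode are value-based. It works over a plain adjacency map, and the class and method names are hypothetical.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/** Hypothetical memoized reachability check over a vertex -> children map. */
final class ReachabilityCache<V> {
    private final Map<V, Set<V>> successors;
    // Key is the ordered pair (from, to); avoids depending on unique toString() values.
    private final Map<List<V>, Boolean> memo = new HashMap<>();

    ReachabilityCache(Map<V, Set<V>> successors) {
        this.successors = successors;
    }

    boolean isConnected(V from, V to) {
        return memo.computeIfAbsent(List.of(from, to), key -> bfs(from, to));
    }

    int cachedPairs() {
        return memo.size();
    }

    /** Breadth-first search that returns as soon as the target is dequeued. */
    private boolean bfs(V from, V to) {
        Set<V> seen = new HashSet<>();
        Deque<V> queue = new ArrayDeque<>();
        queue.add(from);
        seen.add(from);
        while (!queue.isEmpty()) {
            V v = queue.poll();
            if (v.equals(to)) {
                return true;
            }
            for (V next : successors.getOrDefault(v, Set.of())) {
                if (seen.add(next)) {
                    queue.add(next);
                }
            }
        }
        return false;
    }
}
```

Memoized answers stay valid while pruning because removing a redundant edge never changes which vertices can reach which.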

    Other Potential Improvements

    This implementation can be quite slow for large DAGs, especially where the average vertex has many children. There are two reasons for this: the efficiency of the algorithm itself, and the implementation of the connection cache. The algorithm scales as $O(v c^2 b d)$, where v is the number of vertices, c is the number of children of the average vertex, b is the breadth of the DAG at the average vertex, and d is the depth of the DAG at the average vertex. The cache is a simple HashMap that tracks whether a path exists between two DAG vertices. Adding the cache gave me a 14-20x performance improvement over the original non-cached implementation. However, as the DAG grows larger, the overhead of the cache can become significant.
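A different technique, not the pairwise one measured above, avoids both the c^2 pair checks and the cache: visit vertices in reverse topological order, keep a BitSet of everything each vertex reaches, and scan each vertex's children in increasing topological order, keeping a child only if it is not already reachable through an earlier child. This sketch assumes vertices are numbered 0..n-1 in topological order (every edge u -> w has u < w). For the conversation DAG in the question, sorting by start time should give such an order, since a parent's start <= its end < the child's start. Class and method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;

/** Hypothetical BitSet-based transitive reduction for a topologically numbered DAG. */
final class BitSetTransitiveReduction {
    /** children.get(v) lists the children of v; returns the reduced child lists. */
    static List<List<Integer>> reduce(List<List<Integer>> children) {
        int n = children.size();
        BitSet[] reach = new BitSet[n]; // reach[v] = vertices reachable from v (excluding v)
        List<List<Integer>> reduced = new ArrayList<>();
        for (int v = 0; v < n; v++) {
            reduced.add(new ArrayList<>());
        }
        for (int v = n - 1; v >= 0; v--) {
            reach[v] = new BitSet(n);
            List<Integer> sorted = new ArrayList<>(children.get(v));
            sorted.sort(null); // natural order == topological order under our assumption
            for (int c : sorted) {
                if (c <= v || c >= n) {
                    throw new IllegalArgumentException("not topologically numbered: " + v + " -> " + c);
                }
                if (reach[v].get(c)) {
                    continue; // already reachable via an earlier child, so v -> c is redundant
                }
                reduced.get(v).add(c);
                reach[v].set(c);
                reach[v].or(reach[c]); // reach[c] is complete because c > v was processed first
            }
        }
        return reduced;
    }
}
```

Each kept edge costs one BitSet OR over n bits, so the work is roughly proportional to E * n / 64 machine words plus sorting the child lists; removed edges cost only a bit lookup.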

    If you are still having performance problems, one option is to prune the DAG incrementally, rather than waiting until all relationships have been added. Depending on the relationships in your DAG, this can help by reducing the average number of children and keeping the connection cache small. In my most recent test (4500 vertices), I got a substantial improvement by pruning the DAG after each group of 10-15 vertices was added. Combined with the other improvements to this algorithm, pruning incrementally reduced processing time from 4-6 hours to about 10 minutes.
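A variation on incremental pruning, taken one vertex at a time rather than in groups: if conversations are inserted in order of start time, a newly added conversation can never be a parent of an existing one (that would require its end to precede an earlier start), so it only gains incoming edges. Adding such a sink cannot make any existing edge redundant, so only the new vertex's candidate parents need minimizing: drop any candidate that is an ancestor of another candidate. The sketch below relies on that insertion-order assumption, and its names are hypothetical.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

/** Hypothetical builder that keeps the DAG reduced as each new sink vertex arrives. */
final class IncrementalReducedDag<V> {
    private final Map<V, Set<V>> parents = new HashMap<>();

    /**
     * Adds child with its candidate parents (all already present) and returns the parents kept.
     * Assumes child is new and never later becomes a parent of an existing vertex.
     */
    Set<V> addVertex(V child, Set<V> candidateParents) {
        Set<V> ancestorsOfCandidates = new HashSet<>();
        for (V p : candidateParents) {
            collectAncestors(p, ancestorsOfCandidates);
        }
        Set<V> kept = new LinkedHashSet<>();
        for (V p : candidateParents) {
            // p -> child is redundant if p already reaches another candidate parent.
            if (!ancestorsOfCandidates.contains(p)) {
                kept.add(p);
            }
        }
        parents.put(child, kept);
        return kept;
    }

    /** Adds every proper ancestor of v (following parent links) to out. */
    private void collectAncestors(V v, Set<V> out) {
        Deque<V> stack = new ArrayDeque<>();
        stack.push(v);
        while (!stack.isEmpty()) {
            for (V parent : parents.getOrDefault(stack.pop(), Set.of())) {
                if (out.add(parent)) {
                    stack.push(parent);
                }
            }
        }
    }
}
```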

    Testing and Validation

    I have unit tests around this and I'm fairly confident that it works as expected, but I'm very willing to investigate potential problems with the algorithm. To that end, I added tests specifically for cthiebaud's scenario, just in case I somehow missed a corner case in my other testing.
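One property-based check that such unit tests could use, sketched here with hypothetical names over plain adjacency maps: a claimed reduction is correct if it only uses original edges, preserves reachability from every vertex, and contains no edge that can be bypassed by another path. It does one search per vertex and per kept edge, so it is intended for test-sized graphs.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/** Hypothetical test helper for validating a transitive reduction of a DAG. */
final class ReductionCheck {
    static boolean isTransitiveReduction(Map<String, Set<String>> original, Map<String, Set<String>> reduced) {
        Set<String> vertices = new HashSet<>(original.keySet());
        original.values().forEach(vertices::addAll);
        vertices.addAll(reduced.keySet());
        for (String v : vertices) {
            Set<String> kept = reduced.getOrDefault(v, Set.of());
            // 1. Every kept edge must exist in the original graph.
            if (!original.getOrDefault(v, Set.of()).containsAll(kept)) {
                return false;
            }
            // 2. Reachability from v must be unchanged.
            if (!reach(original, v, null).equals(reach(reduced, v, null))) {
                return false;
            }
            // 3. No kept edge v -> w may be replaceable by a longer path.
            for (String w : kept) {
                if (reach(reduced, v, w).contains(w)) {
                    return false;
                }
            }
        }
        return true;
    }

    /** Vertices reachable from start, ignoring the single edge start -> skipped (if non-null). */
    private static Set<String> reach(Map<String, Set<String>> g, String start, String skipped) {
        Set<String> seen = new HashSet<>();
        Deque<String> stack = new ArrayDeque<>();
        stack.push(start);
        while (!stack.isEmpty()) {
            String v = stack.pop();
            for (String next : g.getOrDefault(v, Set.of())) {
                if (v.equals(start) && next.equals(skipped)) {
                    continue;
                }
                if (seen.add(next)) {
                    stack.push(next);
                }
            }
        }
        return seen;
    }
}
```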

    Below is a visualization of the result. The left-hand graph is the original, and the right-hand graph is after pruning. These pictures were generated by rendering DOT output from jgrapht's DOTExporter.

    [Figure: Sample DAG prune results (left: original DAG, right: after pruning)]

    This is what I would have expected, so I still think the implementation is working properly.