java algorithm tree graph-theory

Minimum edges to form path with length L


I came across this problem.

Given a weighted tree T, find the minimum number of edges in a simple path (no repeated vertices or edges) whose weight (the sum of its edge weights) is exactly L.


More details:

L is given as input and it can be different for each case. There are N vertices in the tree numbered from 0 to N - 1.

My first thought was that the best I can do is go over all N^2 paths in T. Here is runnable code with an example input.

import java.util.*;
class Edge {
    int toVertex, weight;
    Edge(int v, int w) {
        toVertex = v; weight = w;
    }
}
class Solver {
    // method called with the tree T given as adjacency list and the path length L to achieve
    // method to return minimum edges to create path of length L or -1 if impossible
    public static int solve(List<List<Edge>> T, long L) {
        int min = (int) 1e9;
        for (int i = 0; i < T.size(); i++) {
            min = Math.min(min, test(T, L, i, -1, 0, 0));
        }
        if (min == (int) 1e9) {
            return -1;
        } else {
            return min;
        }
    }
    static int test(List<List<Edge>> T, long L, int vertex, int parent, long length, int edges) {
        if (length == L) {
            return edges;
        } else if (length < L) {
            int min = (int) 1e9;
            for (Edge edge : T.get(vertex)) {
                if (edge.toVertex != parent) {
                    min = Math.min(min, test(T, L, edge.toVertex, vertex, length + edge.weight, edges + 1));
                }
            }
            return min;
        } else {
            return (int) 1e9; // overshoot
        }
    }
}
// provided code
public class Main {
    static void putEdge(List<List<Edge>> T, int vertex1, int vertex2, int weight) {
        T.get(vertex1).add(new Edge(vertex2, weight));
        T.get(vertex2).add(new Edge(vertex1, weight));
    }
    public static void main(String[] args) {
        // example input
        List<List<Edge>> T = new ArrayList<List<Edge>>();
        int N = 8;
        for (int i = 0; i < N; i++) T.add(new ArrayList<Edge>());
        putEdge(T, 0, 1, 2);
        putEdge(T, 1, 2, 1);
        putEdge(T, 1, 3, 2);
        putEdge(T, 2, 6, 1);
        putEdge(T, 6, 7, 1);
        putEdge(T, 3, 4, 1);
        putEdge(T, 3, 5, 4);
        System.out.println(Solver.solve(T, 5L)); // the path from 4 to 5 has 2 edges and length 5
    }
}

But this exceeds the time limit when N reaches around 10,000. I also considered binary search on the answer, but checking whether a particular answer is achievable looks just as hard as solving the original problem.

Is there a more efficient way to solve this to somehow avoid testing all paths?


Solution

  • There are two main ways to solve this problem; I will describe the simpler method.

    To start, root the tree arbitrarily (vertex 0 is a convenient choice since it always exists). Let dist[x] denote the sum of the edge weights on the path from the root to x, and let depth[x] denote the number of edges on this path. For any two distinct nodes u and v there is exactly one simple path between them, which goes from u up to their lowest common ancestor (LCA) and then down to v. Its total weight is dist[u] + dist[v] - 2 * dist[LCA(u, v)], because the edges from the root to the LCA are counted in both dist[u] and dist[v]. Similarly, the number of edges on the path is depth[u] + depth[v] - 2 * depth[LCA(u, v)].

    Next, consider every node n as a possible LCA of the two endpoints of a path of weight L. In that case both endpoints must lie in the subtree of n (which includes n itself). To find the best such path at each node, we store a map for each node n that associates every distance dist[x] occurring in the subtree of n with the minimum depth[x] among the nodes x of that subtree at that distance.

    To process a node n, we iterate over its children c. For each child we first compute c's map recursively and then combine it into the map of n in two stages. First, for each (d, e) key-value pair in the child's map, we check the map of n for a key k that satisfies k + d - 2 * dist[n] = L, i.e. k = L + 2 * dist[n] - d. If such a key exists with value m, we have found a path of weight L through n with e + m - 2 * depth[n] edges, and we update the answer with the minimum of the current answer and that count. Second, after all lookups for the subtree of c are done, we merge c's map into the map of n, keeping the minimum depth for every distance seen so far, so that optimal paths can still be found when later subtrees are considered. Doing the lookups before the merge guarantees that two nodes from the same child subtree are never paired, since their LCA would not be n.

    To update these maps efficiently, we always merge the map with fewer keys into the map with more keys. In the worst case, all nodes have distinct distances from the root. Whenever a particular node's distance is copied from one map into another, the map it ends up in is at least twice the size of the map it came from. No map can exceed N keys, so each entry is copied at most log2 N times. Every node therefore takes part in O(log N) updates, so both the time (with expected O(1) hash map operations) and the space are O(N log N).

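    // Drop-in replacement for Solver.solve; needs java.util.* and java.util.stream.* imports (Java 10+ for var).
    public static int solve(List<List<Edge>> T, long L) {
        // minEdgesForDist.get(x): map from dist[y] to the minimum depth[y] over all y in the subtree of x
        var minEdgesForDist = Stream.<Map<Long, Integer>>generate(HashMap::new).limit(T.size())
                .collect(Collectors.toCollection(ArrayList::new));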
        return new Object() { 
            // creating new object to define methods inside the context of the solve method
            int dfs(int node, int par, int depth, long dist) {
                minEdgesForDist.get(node).put(dist, depth); // for node itself
                int ret = Integer.MAX_VALUE;
                for (var edge : T.get(node))
                    if (edge.toVertex != par) {
                        ret = Math.min(ret, dfs(edge.toVertex, node, depth + 1, dist + edge.weight));
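                        // small-to-large: keep the larger map in node's slot so only the smaller one is iterated
                        if (minEdgesForDist.get(edge.toVertex).size() > minEdgesForDist.get(node).size())
                            Collections.swap(minEdgesForDist, edge.toVertex, node);
                        // stage 1: pair each entry of the smaller map with a matching distance in the larger one
                        for (var entry : minEdgesForDist.get(edge.toVertex).entrySet()) {
                            var other = minEdgesForDist.get(node).get(L + 2 * dist - entry.getKey());
                            if (other != null)
                                ret = Math.min(ret, entry.getValue() + other - 2 * depth);
                        }
                        // stage 2: merge the smaller map into the larger one, keeping the minimum depth per distance
                        for (var entry : minEdgesForDist.get(edge.toVertex).entrySet())
                            minEdgesForDist.get(node).merge(entry.getKey(), entry.getValue(), Math::min);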
                    }
                return ret;
            }
            
            int getAnswer() {
                int ret = dfs(0, -1, 0, 0L); // root at vertex 0, which has no parent
                return ret != Integer.MAX_VALUE ? ret : -1;
            }
        }.getAnswer();
    }
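The rooting step and the two path formulas from the answer can be checked on the question's example tree. This is a minimal sketch. It assumes edges are passed as hypothetical `{u, v, w}` triples and uses a naive LCA that climbs parent pointers (O(N) per query) only to verify the formulas. The class `PathFormula` and its methods are illustrative and are not part of the answer's code.

```java
import java.util.*;

// Illustrative sketch: root at vertex 0, compute dist[] and depth[] with an
// iterative DFS, then evaluate the LCA-based path formulas.
class PathFormula {
    static long[] dist;
    static int[] depth, parent;

    // edges[i] = {u, v, w}; fills dist, depth and parent for a tree rooted at 0.
    static void root(int n, int[][] edges) {
        List<List<int[]>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        for (int[] e : edges) {
            adj.get(e[0]).add(new int[]{e[1], e[2]});
            adj.get(e[1]).add(new int[]{e[0], e[2]});
        }
        dist = new long[n];
        depth = new int[n];
        parent = new int[n];
        Arrays.fill(parent, -1);
        boolean[] seen = new boolean[n];
        Deque<Integer> stack = new ArrayDeque<>();
        stack.push(0);
        seen[0] = true;
        while (!stack.isEmpty()) {
            int x = stack.pop();
            for (int[] e : adj.get(x)) {
                int y = e[0];
                if (seen[y]) continue;
                seen[y] = true;
                parent[y] = x;
                dist[y] = dist[x] + e[1];   // weight from the root
                depth[y] = depth[x] + 1;    // edge count from the root
                stack.push(y);
            }
        }
    }

    // Naive LCA: lift the deeper vertex, then lift both until they meet.
    static int lca(int u, int v) {
        while (depth[u] > depth[v]) u = parent[u];
        while (depth[v] > depth[u]) v = parent[v];
        while (u != v) { u = parent[u]; v = parent[v]; }
        return u;
    }

    // dist[u] + dist[v] - 2 * dist[LCA(u, v)]
    static long pathWeight(int u, int v) { return dist[u] + dist[v] - 2 * dist[lca(u, v)]; }

    // depth[u] + depth[v] - 2 * depth[LCA(u, v)]
    static int pathEdges(int u, int v) { return depth[u] + depth[v] - 2 * depth[lca(u, v)]; }

    public static void main(String[] args) {
        int[][] edges = {{0,1,2},{1,2,1},{1,3,2},{2,6,1},{6,7,1},{3,4,1},{3,5,4}};
        root(8, edges);
        System.out.println(pathWeight(4, 5) + " " + pathEdges(4, 5)); // 5 2
    }
}
```

For vertices 4 and 5 the LCA is 3, so the weight is 5 + 8 - 2 * 4 = 5 and the edge count is 3 + 3 - 2 * 2 = 2. This matches the path the question's example expects.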
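The per-node maps described in the answer can be made concrete with a naive sketch. The hypothetical class `SubtreeMaps` builds, for every vertex, the map from dist[x] to the minimum depth[x] by copying each child's map into a fresh map. It deliberately leaves out the small-to-large trick, so it can take O(N^2) time and memory. Its only purpose is to show what each map contains.

```java
import java.util.*;

// Naive illustration: every vertex gets its own map dist -> min depth over its subtree.
class SubtreeMaps {
    // edges[i] = {u, v, w}; the tree is rooted at vertex 0.
    static Map<Integer, Map<Long, Integer>> buildAll(int n, int[][] edges) {
        List<List<int[]>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        for (int[] e : edges) {
            adj.get(e[0]).add(new int[]{e[1], e[2]});
            adj.get(e[1]).add(new int[]{e[0], e[2]});
        }
        Map<Integer, Map<Long, Integer>> all = new HashMap<>();
        build(adj, 0, -1, 0L, 0, all);
        return all;
    }

    // Returns the map for x after copying every child's map into a new map.
    private static Map<Long, Integer> build(List<List<int[]>> adj, int x, int par, long dist, int depth,
                                            Map<Integer, Map<Long, Integer>> all) {
        Map<Long, Integer> m = new HashMap<>();
        m.put(dist, depth); // x itself
        for (int[] e : adj.get(x)) {
            if (e[0] == par) continue;
            Map<Long, Integer> child = build(adj, e[0], x, dist + e[1], depth + 1, all);
            for (Map.Entry<Long, Integer> en : child.entrySet())
                m.merge(en.getKey(), en.getValue(), Math::min); // keep the smaller depth
        }
        all.put(x, m);
        return m;
    }

    public static void main(String[] args) {
        int[][] edges = {{0,1,2},{1,2,1},{1,3,2},{2,6,1},{6,7,1},{3,4,1},{3,5,4}};
        Map<Integer, Map<Long, Integer>> all = buildAll(8, edges);
        System.out.println(new TreeMap<>(all.get(1))); // {2=1, 3=2, 4=2, 5=3, 8=3}
    }
}
```

In vertex 1's map, vertices 3 and 6 are both at distance 4. The map keeps depth 2 (vertex 3) rather than 3 (vertex 6), which is exactly the "minimum depth per distance" rule.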
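The two-stage combine step from the answer can be isolated in a small helper, shown here as the hypothetical method `combine`. Stage 1 looks up k = L + 2 * dist[n] - d for every child entry. Stage 2 merges the child's entries only afterwards, so that two vertices from the same child subtree are never paired. This sketch mirrors the answer's inner loops but is not taken from it.

```java
import java.util.*;

class CombineStep {
    // nodeMap: dist -> min depth for n and the children already processed.
    // childMap: dist -> min depth for the subtree of the current child c.
    // Returns the fewest edges of a weight-L path through n found in this step
    // (Integer.MAX_VALUE if none), then merges childMap into nodeMap.
    static int combine(Map<Long, Integer> nodeMap, Map<Long, Integer> childMap,
                       long L, long distN, int depthN) {
        int best = Integer.MAX_VALUE;
        // Stage 1: k + d - 2 * dist[n] = L  <=>  k = L + 2 * dist[n] - d
        for (Map.Entry<Long, Integer> en : childMap.entrySet()) {
            Integer other = nodeMap.get(L + 2 * distN - en.getKey());
            if (other != null)
                best = Math.min(best, en.getValue() + other - 2 * depthN);
        }
        // Stage 2: fold the child's entries in, keeping the minimum depth per distance.
        for (Map.Entry<Long, Integer> en : childMap.entrySet())
            nodeMap.merge(en.getKey(), en.getValue(), Math::min);
        return best;
    }

    public static void main(String[] args) {
        // Vertex 3 of the question's example: dist[3] = 4, depth[3] = 2, with L = 5.
        Map<Long, Integer> node3 = new HashMap<>();
        node3.put(4L, 2);                                       // vertex 3 itself
        int first = combine(node3, Map.of(5L, 3), 5L, 4L, 2);   // child 4: no match yet
        int second = combine(node3, Map.of(8L, 3), 5L, 4L, 2);  // child 5: pairs with vertex 4
        System.out.println(first == Integer.MAX_VALUE); // true
        System.out.println(second);                     // 2 (path 4-3-5)
    }
}
```

When child 5 is processed, its entry (8, 3) needs k = 5 + 8 - 8 = 5. That key was added to vertex 3's map by child 4, giving 3 + 3 - 2 * 2 = 2 edges.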
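The small-to-large argument can also be observed with a toy experiment. The hypothetical class `SmallToLarge` below uses plain sets instead of the dist-to-depth maps. It always copies the smaller set into the larger one and counts element copies. A balanced pairwise merge order copies some elements log2 N times, while a chain order, which resembles a long path in the tree, copies each element only once. Both stay within the N log2 N bound.

```java
import java.util.*;

class SmallToLarge {
    static long moves = 0; // number of element copies performed

    // Copies the smaller of a, b into the larger one and returns the larger (now the union).
    static Set<Integer> merge(Set<Integer> a, Set<Integer> b) {
        if (a.size() < b.size()) { Set<Integer> t = a; a = b; b = t; }
        for (int x : b) { a.add(x); moves++; }
        return a;
    }

    // Merges adjacent pairs of sets, level by level, until one set remains.
    static long balanced(int n) {
        moves = 0;
        List<Set<Integer>> sets = new ArrayList<>();
        for (int i = 0; i < n; i++) sets.add(new HashSet<>(List.of(i)));
        while (sets.size() > 1) {
            List<Set<Integer>> next = new ArrayList<>();
            for (int i = 0; i + 1 < sets.size(); i += 2) next.add(merge(sets.get(i), sets.get(i + 1)));
            if (sets.size() % 2 == 1) next.add(sets.get(sets.size() - 1));
            sets = next;
        }
        return moves;
    }

    // Merges each new singleton into one growing set.
    static long chain(int n) {
        moves = 0;
        Set<Integer> acc = new HashSet<>(List.of(0));
        for (int i = 1; i < n; i++) acc = merge(acc, new HashSet<>(List.of(i)));
        return moves;
    }

    public static void main(String[] args) {
        System.out.println(balanced(1024)); // 5120 = (1024 / 2) * log2(1024)
        System.out.println(chain(1024));    // 1023
    }
}
```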
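As a sanity check, the sketch below compares the answer's method with the brute force from the question on small random trees. It restates the answer's method with an explicit recursive helper instead of the anonymous object. The random-tree generator, the class name `CrossCheck`, the seed and the trial counts are arbitrary choices for illustration. It also assumes positive integer weights and L >= 1, which is what the brute force's overshoot pruning relies on.

```java
import java.util.*;

class Edge {
    int toVertex, weight;
    Edge(int v, int w) { toVertex = v; weight = w; }
}

class CrossCheck {
    static final int NONE = Integer.MAX_VALUE;

    // Answer's approach: small-to-large merging of dist -> min depth maps, rooted at 0.
    static int fast(List<List<Edge>> T, long L) {
        List<Map<Long, Integer>> maps = new ArrayList<>();
        for (int i = 0; i < T.size(); i++) maps.add(new HashMap<>());
        int ret = dfs(T, L, maps, 0, -1, 0, 0L);
        return ret == NONE ? -1 : ret;
    }

    static int dfs(List<List<Edge>> T, long L, List<Map<Long, Integer>> maps,
                   int node, int par, int depth, long dist) {
        maps.get(node).put(dist, depth);
        int ret = NONE;
        for (Edge e : T.get(node)) {
            if (e.toVertex == par) continue;
            int c = e.toVertex;
            ret = Math.min(ret, dfs(T, L, maps, c, node, depth + 1, dist + e.weight));
            if (maps.get(c).size() > maps.get(node).size()) Collections.swap(maps, c, node);
            Map<Long, Integer> big = maps.get(node), small = maps.get(c);
            for (Map.Entry<Long, Integer> en : small.entrySet()) { // stage 1: lookups
                Integer other = big.get(L + 2 * dist - en.getKey());
                if (other != null) ret = Math.min(ret, en.getValue() + other - 2 * depth);
            }
            for (Map.Entry<Long, Integer> en : small.entrySet())   // stage 2: merge
                big.merge(en.getKey(), en.getValue(), Math::min);
        }
        return ret;
    }

    // Brute force from the question: walk away from every start vertex.
    static int brute(List<List<Edge>> T, long L) {
        int best = NONE;
        for (int i = 0; i < T.size(); i++) best = Math.min(best, walk(T, L, i, -1, 0L, 0));
        return best == NONE ? -1 : best;
    }

    static int walk(List<List<Edge>> T, long L, int v, int par, long len, int edges) {
        if (len == L) return edges;
        if (len > L) return NONE; // overshoot
        int best = NONE;
        for (Edge e : T.get(v))
            if (e.toVertex != par) best = Math.min(best, walk(T, L, e.toVertex, v, len + e.weight, edges + 1));
        return best;
    }

    // Random tree: vertex i attaches to a random earlier vertex with weight in [1, maxW].
    static List<List<Edge>> randomTree(int n, int maxW, Random rnd) {
        List<List<Edge>> T = new ArrayList<>();
        for (int i = 0; i < n; i++) T.add(new ArrayList<>());
        for (int i = 1; i < n; i++) {
            int p = rnd.nextInt(i), w = 1 + rnd.nextInt(maxW);
            T.get(i).add(new Edge(p, w));
            T.get(p).add(new Edge(i, w));
        }
        return T;
    }

    public static void main(String[] args) {
        Random rnd = new Random(1);
        for (int trial = 0; trial < 500; trial++) {
            List<List<Edge>> T = randomTree(2 + rnd.nextInt(30), 5, rnd);
            long L = 1 + rnd.nextInt(20);
            if (fast(T, L) != brute(T, L)) throw new AssertionError("mismatch on trial " + trial);
        }
        System.out.println("all trials agree");
    }
}
```

Agreement on random inputs is evidence, not a proof. The brute force stays as a reference only because these trees are tiny.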