Search code examples
javaalgorithmdata-structuresdijkstraheap

Dijkstras algorithm shortest path in directed graph, find the last node taken to destination vertex


The code below was taken from https://algorithms.tutorialhorizon.com/dijkstras-shortest-path-algorithm-spt-adjacency-list-and-min-heap-java-implementation/

it finds the shortest distance to each vertex from a given source vertex but does not as of yet offer a way to track the path taken. Is there an easy fix to this? if not how should i approach this?


import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.LinkedList;
import java.util.StringTokenizer;

public class DijkstraMinHeap {
    static class Edge {
        int source;
        int destination;
        int weight;

        public Edge(int source, int destination, int weight) {
            this.source = source;
            this.destination = destination;
            this.weight = weight;
        }
    }

    static class HeapNode{
        int vertex;
        int distance;
    }
    static class Graph {
        int vertices;
        LinkedList<Edge>[] adjacencylist;

        Graph(int vertices) {
            this.vertices = vertices;
            adjacencylist = new LinkedList[vertices];
            for (int i = 0; i <vertices ; i++) {
                adjacencylist[i] = new LinkedList<>();
            }
        }

        public void addEdge(int source, int destination, int weight) {
            Edge edge = new Edge(source, destination, weight);
            adjacencylist[source].addFirst(edge);

            /**edge = new Edge(destination, source, weight);
            adjacencylist[destination].addFirst(edge); //for undirected graph*/
        }

        public void dijkstra_GetMinDistances(int sourceVertex){
            int INFINITY = Integer.MAX_VALUE;
            boolean[] SPT = new boolean[vertices];

// //create heapNode for all the vertices
            HeapNode [] heapNodes = new HeapNode[vertices];
            for (int i = 0; i <vertices ; i++) {
                heapNodes[i] = new HeapNode();
                heapNodes[i].vertex = i;
                heapNodes[i].distance = INFINITY;
            }

            //decrease the distance for the first index
            heapNodes[sourceVertex].distance = 0;

            //add all the vertices to the DijkstraMinHeap.MinHeap
            MinHeap minHeap = new MinHeap(vertices);
            for (int i = 0; i <vertices ; i++) {
                minHeap.insert(heapNodes[i]);
            }
            //while minHeap is not empty
            while(!minHeap.isEmpty()){
                //extract the min
                HeapNode extractedNode = minHeap.extractMin();

                //extracted vertex
                int extractedVertex = extractedNode.vertex;
                SPT[extractedVertex] = true;

                //iterate through all the adjacent vertices
                LinkedList<Edge> list = adjacencylist[extractedVertex];
                for (int i = 0; i <list.size() ; i++) {
                    Edge edge = list.get(i);
                    int destination = edge.destination;
                    //only if destination vertex is not present in SPT
                    if(SPT[destination]==false ) {
                        ///check if distance needs an update or not
                        //means check total weight from source to vertex_V is less than
                        //the current distance value, if yes then update the distance
                        int newKey = heapNodes[extractedVertex].distance + edge.weight ;
                        int currentKey = heapNodes[destination].distance;
                        if(currentKey>newKey){
                            decreaseKey(minHeap, newKey, destination);
                            heapNodes[destination].distance = newKey;
                        }
                    }
                }
            }
            //print SPT
            printDijkstra(heapNodes, sourceVertex);
        }

        public void decreaseKey(MinHeap minHeap, int newKey, int vertex){

            //get the index which distance's needs a decrease;
            int index = minHeap.indexes[vertex];

            //get the node and update its value
            HeapNode node = minHeap.mH[index];
            node.distance = newKey;
            minHeap.bubbleUp(index);
        }

        public void printDijkstra(HeapNode[] resultSet, int sourceVertex){
            System.out.println("Dijkstra's Algorithm: (using Adjacency List and Min Heap)");
            for (int i = 0; i <vertices ; i++) {
                String x = String.valueOf(resultSet[i].distance);
                if(resultSet[i].distance == Integer.MAX_VALUE){
                    x = "unreachable";
                }
                System.out.println("Node: " + i + " | predecessor " + resultSet[i].vertex+
                        " | distance: " + x);
            }
        }
    }

    public static void main(String[] args) throws IOException {


        BufferedReader br = new BufferedReader(new FileReader("./vg1.txt"));
        printResult(br);

    }

    public static void printResult(BufferedReader br) throws IOException {
        StringTokenizer st = new StringTokenizer((br.readLine()));
        int vertices = Integer.parseInt(st.nextToken());
        Graph graph = new Graph(vertices);


        int K = Integer.parseInt((st.nextToken()));
        for (int i = 0; i < K; i++) {
            st = new StringTokenizer(br.readLine());
            int from = Integer.parseInt(st.nextToken());
            int to = Integer.parseInt(st.nextToken());
            int weight = Integer.parseInt(st.nextToken());
            graph.addEdge(from, to, weight);
        }
        int source_vertex = 1;
        graph.dijkstra_GetMinDistances(source_vertex);
    }


    static class MinHeap {
        int capacity;
        int currentSize;
        HeapNode[] mH;
        int[] indexes; //will be used to decrease the distance


        public MinHeap(int capacity) {
            this.capacity = capacity;
            mH = new HeapNode[capacity + 1];
            indexes = new int[capacity];
            mH[0] = new HeapNode();
            mH[0].distance = Integer.MIN_VALUE;
            mH[0].vertex = -1;
            currentSize = 0;
        }

        public void display() {
            for (int i = 0; i <= currentSize; i++) {
                System.out.println(" " + mH[i].vertex + " distance " + mH[i].distance);
            }
            System.out.println("________________________");
        }

        public void insert(HeapNode x) {
            currentSize++;
            int idx = currentSize;
            mH[idx] = x;
            indexes[x.vertex] = idx;
            bubbleUp(idx);
        }

        public void bubbleUp(int pos) {
            int parentIdx = pos / 2;
            int currentIdx = pos;
            while (currentIdx > 0 && mH[parentIdx].distance > mH[currentIdx].distance) {
                HeapNode currentNode = mH[currentIdx];
                HeapNode parentNode = mH[parentIdx];

                //swap the positions
                indexes[currentNode.vertex] = parentIdx;
                indexes[parentNode.vertex] = currentIdx;
                swap(currentIdx, parentIdx);
                currentIdx = parentIdx;
                parentIdx = parentIdx / 2;
            }
        }

        public HeapNode extractMin() {
            HeapNode min = mH[1];
            HeapNode lastNode = mH[currentSize];
    // update the indexes[] and move the last node to the top
            indexes[lastNode.vertex] = 1;
            mH[1] = lastNode;
            mH[currentSize] = null;
            sinkDown(1);
            currentSize--;
            return min;
        }

        public void sinkDown(int k) {
            int smallest = k;
            int leftChildIdx = 2 * k;
            int rightChildIdx = 2 * k + 1;
            if (leftChildIdx < heapSize() && mH[smallest].distance > mH[leftChildIdx].distance) {
                smallest = leftChildIdx;
            }
            if (rightChildIdx < heapSize() && mH[smallest].distance > mH[rightChildIdx].distance) {
                smallest = rightChildIdx;
            }
            if (smallest != k) {

                HeapNode smallestNode = mH[smallest];
                HeapNode kNode = mH[k];

                //swap the positions
                indexes[smallestNode.vertex] = k;
                indexes[kNode.vertex] = smallest;
                swap(k, smallest);
                sinkDown(smallest);
            }
        }

        public void swap(int a, int b) {
            HeapNode temp = mH[a];
            mH[a] = mH[b];
            mH[b] = temp;
        }

        public boolean isEmpty() {
            return currentSize == 0;
        }

        public int heapSize() {
            return currentSize;
        }
    }
}

Solution

  • Try to add the Edge[] edgeTo array that will track the the edge that you used to reach specific vertex.

    For example

    edgeTo[0] = null
    edgeTo[1] = Edge( 0 -> 1 )
    edgeto[2] = Edge( 1 -> 2 )
    
    There is the path 0 -> 1 -> 2 and you can find it by traversing backwards.
    

    You can create such an array structure by changing the lines:

    public void dijkstra_GetMinDistances(int sourceVertex){
       int INFINITY = Integer.MAX_VALUE;
       boolean[] SPT = new boolean[vertices];
       Edge[] edgeTo = new Edge[vertices];
       ...      
    

    and you will update this edgeTo array in if statement where you find the shorter distance:

    if(currentKey>newKey){
       edgeTo[destination] = edge;
       decreaseKey(minHeap, newKey, destination);
       heapNodes[destination].distance = newKey;
    }
    

    Finally the method that return the shortest part from source_vertex to any_other_vertex is:

    public List<Edge> pathTo(int vertex) {
        List<Edge> path = new LinkedList<>();
        for (Edge e = edgeTo[vertex]; e != null; e = edgeTo[e.source]) {
            path.addFirst(e);
        }
        return path;
    }
    

    All credits goes to great book and online course https://algs4.cs.princeton.edu/home/