Tags: python, algorithm, data-structures, shortest-path, heap

An algorithm to find the shortest path based on 2 criteria


We start at node 0 and need to get to node n-1 using as few steps as possible. At the same time, each step affects our temperature: some steps add 1 degree and some subtract 1.

The input is in this format -> arr = [4, 3, [[0, 1, -1], [1, 3, -1], [1, 2, 1]]]

arr[0] = 4 is n, the number of nodes. Nodes are numbered from 0, so here the first node is 0 and the last node is 3, and we want to get from node 0 to node 3 (n-1).

arr[1] = 3 is m, the number of available steps (the lists that follow).

arr[2] = [[0, 1, -1], [1, 3, -1], [1, 2, 1]] is a list of length m. Each element [u, v, c] describes one step. u is the node where the step starts and v is the node where it ends (u != v). c is either 1 or -1 and tells whether the step adds or subtracts a degree of temperature.
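Below is a minimal sketch of reading this format into an adjacency list. The function name `build_graph` is hypothetical. The sketch stores each step's index next to the edge, so a path can be reported as step indices directly. There is no need to search the step list for a matching [u, v, c] on every push. It follows the statement literally, so a step goes from u to v only.

```python
def build_graph(arr):
    """Adjacency list: graph[u] holds (v, c, step_index) for every step u -> v."""
    n, m, steps = arr
    graph = [[] for _ in range(n)]
    for idx, (u, v, c) in enumerate(steps[:m]):
        graph[u].append((v, c, idx))  # keep the step index with the edge
    return graph


print(build_graph([4, 3, [[0, 1, -1], [1, 3, -1], [1, 2, 1]]]))
# [[(1, -1, 0)], [(3, -1, 1), (2, 1, 2)], [], []]
```

If the same [u, v, c] is listed twice, both copies stay in the list with their own indices.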

The goal is to use as few steps as possible and to end with a temperature as close to 0 as possible. A shorter path matters more than temperature. Given a choice between path_length=1, temperature=20 and path_length=2, temperature=0, we choose the shorter one (path_length=1).

If there is no valid path, we just print ajajaj.
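The ranking rule can be written as a lexicographic key, (length, abs(temperature)), applied to finished paths. Below is a sketch with a hypothetical helper `pick_best` and made-up candidates. The rules above do not cover ties on both criteria, and `min` simply keeps the first tied candidate.

```python
def pick_best(candidates):
    """candidates: (temperature, length, step_indices) tuples for complete paths."""
    if not candidates:
        return "ajajaj"  # no valid path at all
    # fewer steps wins outright; among equal lengths, temperature closest to 0
    return min(candidates, key=lambda r: (r[1], abs(r[0])))


print(pick_best([(20, 1, [7]), (0, 2, [3, 4])]))  # (20, 1, [7])
print(pick_best([]))                               # ajajaj
```

This key is only safe for complete paths. The absolute temperature part-way along a path says nothing about its final absolute temperature.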

input/output examples:

input: arr = [4, 3, [[0, 1, -1], [1, 3, -1], [1, 2, 1]]]

output: (-2, 2, [0, 1]). Here -2 is the temperature, 2 is the length of the path, and the list holds the indices of the steps used: we used [0, 1, -1] and [1, 3, -1] to get to n-1.

input: [4, 4, [[0, 1, 1], [0, 2, -1], [1, 3, -1], [2, 3, 1]]]

output: (0, 2, [1, 3])

input: [3, 1, [[0, 1, 1]]]

output: ajajaj

input: [5, 5, [[0, 1, 1], [1, 2, -1], [1, 2, 1], [2, 3, -1], [3, 4, -1]]]

output: (0, 4, [0, 2, 3, 4])
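A slow reference implementation is handy for cross-checking a faster solution on small inputs. The hypothetical `brute_force` below enumerates every walk of length 0, 1, 2, ... from node 0 and stops at the first length that reaches n-1. It then applies the ranking rule. The running time is exponential, so it only suits small graphs. Like the statement, it treats steps as one-way (u to v).

```python
def brute_force(arr):
    """Slow reference: try every walk of length 0, 1, 2, ... (exponential time)."""
    n, m, steps = arr
    target = n - 1
    level = [(0, 0, [])]   # partial walks as (current node, temperature, step indices)
    for _ in range(n):     # a shortest path never needs more than n - 1 steps
        done = [(t, len(p), p) for node, t, p in level if node == target]
        if done:
            # all walks in `done` have the same length; pick temperature closest to 0
            return min(done, key=lambda r: (r[1], abs(r[0])))
        level = [
            (v, t + c, p + [idx])
            for node, t, p in level
            for idx, (u, v, c) in enumerate(steps)
            if u == node   # one-way steps, as in the problem statement
        ]
        if not level:
            break
    return "ajajaj"


examples = [
    [4, 3, [[0, 1, -1], [1, 3, -1], [1, 2, 1]]],
    [4, 4, [[0, 1, 1], [0, 2, -1], [1, 3, -1], [2, 3, 1]]],
    [3, 1, [[0, 1, 1]]],
    [5, 5, [[0, 1, 1], [1, 2, -1], [1, 2, 1], [2, 3, -1], [3, 4, -1]]],
]
for arr in examples:
    print(brute_force(arr))
# (-2, 2, [0, 1])
# (0, 2, [0, 2])
# ajajaj
# (0, 4, [0, 2, 3, 4])
```

In the second example, both [0, 2] and [1, 3] have temperature 0. The rules do not say which one to print, and this sketch returns [0, 2] rather than the [1, 3] listed above.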

This is my solution which passes the test cases and 70% of the big input cases:

from heapq import heappush, heappop
from typing import Union


def main(entry: list) -> Union[str, tuple]:
    results = []
    # n == num of possible nodes -> if n=4 then possible nodes are 0, 1, 2, 3
    # m == num of all paths -> yet again they start at 0, so if m=4, the first path is 0 and the last one is 3
    n, m = entry.pop(0), entry.pop(0)

    # loop through the paths and store them in a graph
    graph = [[] for _ in range(n)]
    paths = entry.pop()
    for i in range(m):
        # u == start of the path, v == end of the path, c == type of the path
        # u != v, c is 'ohniva' (fiery, +1) or 'ledova' (icy, -1)
        u, v, c = paths[i]
        graph[u].append((v, c))

    pq = [(0, 0, [])]
    visited = set()
    max_path = float("inf")
    while pq:
        temp, node, path = heappop(pq)
        # NOTE this probably isn't correct
        if len(path) > max_path:
            continue
        if node == n - 1:
            if len(path) < max_path:
                max_path = len(path)
            result = (temp, len(path), path)
            results.append(result)
            visited = set()

        if node not in visited:
            visited.add(node)
            for end, c in graph[node]:
                new_temp = temp + c  # +/- 1
                new_node = None
                # NOTE this probably isn't correct
                for i, p in enumerate(paths):
                    if p == [node, end, c]:
                        new_node = i
                        break

                new_path = path + [new_node]
                heappush(pq, (new_temp, end, new_path))

    if n - 1 not in visited:
        return "ajajaj"

    if len(results) > 1:
        return sorted(results, key=lambda x: (x[1], abs(x[0])))[0]
    else:
        return results[0]

Is there a bug in the algorithm, or is the whole approach wrong? I've been trying to figure it out for the past 4 hours...
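One concrete problem in the code above is the heap ordering. Heap entries are `(temp, node, path)`, and Python compares tuples element by element. So the heap always pops the lowest (most negative) temperature first, no matter how many steps that path has taken. Combined with `visited`, this lets a longer path claim a node before the shorter one gets there. A minimal demonstration with two made-up entries:

```python
from heapq import heappush, heappop

pq = []
heappush(pq, (0, 3, [1]))          # made-up entry: node 3 after 1 step, temperature 0
heappush(pq, (-3, 3, [4, 5, 6]))   # made-up entry: node 3 after 3 steps, temperature -3
print(heappop(pq))                 # (-3, 3, [4, 5, 6]): the 3-step path is popped first
```

Switching the key to `(len(path), abs(temp), ...)` would not fix this on its own. A partial path at temperature 2 can still finish at 0, so each node needs to remember every temperature it can be reached with, not just the first arrival.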

EDIT: example of a case which doesn't pass:

input:

[93, 289, [[62, 31, -1], [45, 27, 1], [11, 7, 1], [80, 74, 1], [15, 82, 1], [56, 12, 1], [49, 85, 1], [61, 21, -1], [90, 35, 1], [13, 68, 1], [7, 83, -1], [65, 68, -1], [44, 74, -1], [48, 59, 1], [39, 45, 1], [1, 82, 1], [32, 62, -1], [72, 82, 1], [27, 23, -1], [27, 73, 1], [69, 35, -1], [24, 77, 1], [8, 66, 1], [68, 8, -1], [14, 61, 1], [80, 76, 1], [82, 8, 1], [76, 61, -1], [48, 53, -1], [90, 33, 1], [11, 86, 1], [52, 42, -1], [46, 36, 1], [26, 69, 1], [46, 64, -1], [0, 14, -1], [31, 60, 1], [88, 11, -1], [28, 60, -1], [73, 78, 1], [52, 2, 1], [23, 82, 1], [63, 92, 1], [21, 84, -1], [80, 7, 1], [91, 49, -1], [62, 65, -1], [92, 16, -1], [13, 59, -1], [14, 40, 1], [58, 86, 1], [6, 60, 1], [21, 59, 1], [68, 12, 1], [92, 75, -1], [83, 36, -1], [90, 60, -1], [2, 84, -1], [22, 50, 1], [72, 21, -1], [47, 3, 1], [51, 9, -1], [67, 77, -1], [92, 10, -1], [80, 20, 1], [55, 5, -1], [46, 64, -1], [22, 15, 1], [87, 24, -1], [80, 71, 1], [61, 39, 1], [59, 83, 1], [60, 63, 1], [10, 12, 1], [70, 75, -1], [33, 27, -1], [75, 14, 1], [52, 4, -1], [45, 61, 1], [59, 55, -1], [30, 37, 1], [10, 38, 1], [56, 4, -1], [51, 39, -1], [35, 3, 1], [49, 37, -1], [40, 6, -1], [47, 90, -1], [20, 68, -1], [74, 38, 1], [88, 18, -1], [25, 0, 1], [51, 73, -1], [75, 91, 1], [14, 75, -1], [86, 73, -1], [52, 21, -1], [44, 89, 1], [68, 80, -1], [82, 37, 1], [59, 78, 1], [48, 43, -1], [47, 88, -1], [77, 60, -1], [32, 22, 1], [55, 6, 1], [49, 77, -1], [27, 48, 1], [46, 31, -1], [65, 57, 1], [83, 11, 1], [68, 84, 1], [29, 27, 1], [87, 59, 1], [75, 41, -1], [46, 44, -1], [67, 29, -1], [75, 55, -1], [10, 19, 1], [52, 46, 1], [12, 20, -1], [0, 4, 1], [39, 27, -1], [12, 28, -1], [1, 61, 1], [79, 34, -1], [45, 79, -1], [13, 86, 1], [20, 74, -1], [35, 60, -1], [10, 89, -1], [70, 44, 1], [14, 3, -1], [81, 7, 1], [3, 78, -1], [52, 6, -1], [62, 73, -1], [0, 34, 1], [2, 45, -1], [50, 25, 1], [73, 63, -1], [70, 92, -1], [64, 80, 1], [66, 53, -1], [35, 7, 1], [0, 84, 1], [85, 14, 1], [2, 42, 1], [26, 80, -1], [18, 24, 
-1], [31, 86, 1], [78, 45, -1], [21, 66, -1], [61, 57, -1], [46, 49, -1], [19, 82, -1], [55, 30, -1], [1, 6, -1], [29, 33, 1], [44, 45, 1], [66, 91, 1], [42, 58, 1], [56, 26, -1], [36, 48, 1], [32, 41, 1], [12, 90, -1], [92, 24, -1], [76, 47, 1], [47, 25, -1], [90, 36, -1], [22, 37, -1], [70, 57, 1], [51, 31, 1], [32, 13, -1], [39, 10, -1], [13, 36, 1], [67, 50, 1], [13, 24, 1], [11, 12, 1], [26, 51, -1], [54, 47, -1], [19, 43, 1], [76, 88, 1], [40, 39, -1], [75, 91, -1], [31, 92, 1], [36, 13, -1], [51, 47, -1], [14, 1, 1], [92, 17, -1], [87, 79, -1], [16, 9, -1], [17, 84, -1], [69, 43, -1], [33, 5, 1], [23, 17, -1], [20, 49, 1], [0, 61, 1], [25, 9, 1], [77, 12, 1], [80, 44, -1], [52, 23, 1], [18, 1, 1], [75, 50, -1], [86, 92, -1], [52, 6, -1], [37, 51, 1], [20, 91, 1], [7, 85, -1], [76, 48, 1], [70, 11, 1], [78, 75, 1], [57, 16, -1], [62, 31, -1], [29, 3, -1], [34, 79, -1], [50, 71, 1], [6, 90, 1], [13, 77, -1], [62, 54, 1], [24, 38, -1], [49, 7, 1], [88, 82, -1], [73, 8, -1], [13, 17, 1], [78, 66, 1], [75, 6, -1], [9, 51, -1], [31, 58, -1], [76, 74, -1], [54, 34, -1], [85, 59, -1], [69, 68, -1], [33, 67, 1], [36, 55, -1], [92, 23, 1], [28, 65, 1], [48, 24, -1], [2, 71, -1], [53, 59, -1], [78, 61, 1], [82, 79, 1], [91, 58, 1], [82, 76, -1], [61, 70, -1], [92, 10, -1], [4, 26, -1], [76, 86, 1], [24, 20, 1], [41, 59, -1], [44, 46, -1], [64, 33, -1], [84, 14, 1], [54, 1, -1], [21, 82, 1], [77, 8, 1], [10, 40, 1], [74, 33, 1], [77, 15, -1], [57, 78, 1], [24, 26, -1], [36, 2, -1], [74, 87, -1], [83, 90, 1], [63, 49, -1], [12, 91, 1], [54, 36, -1], [72, 26, -1], [73, 36, 1], [35, 2, 1], [70, 72, 1], [73, 26, 1], [76, 23, -1], [59, 69, -1], [27, 5, -1], [87, 24, 1], [61, 84, -1], [77, 33, 1], [63, 68, -1], [87, 36, -1], [20, 77, -1], [31, 11, -1], [90, 63, -1], [51, 62, 1], [91, 77, 1], [13, 7, -1], [18, 55, -1], [75, 33, -1], [56, 74, -1]]]

my output:

(-2, 4, [35, 24, 244, 141])

correct output:

(-1, 3, [35, 76, 54])
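A quick way to see what the correct output actually does is to replay it. The hypothetical `check_answer` below walks a list of step indices from node 0. It returns (temperature, length), or None if a step does not start where the walk currently is. The `bidirectional` flag decides whether a step [u, v, c] may also be walked from v to u. In the correct output above, steps 35, 76 and 54 are [0, 14, -1], [75, 14, 1] and [92, 75, -1]. With bidirectional=False the replay fails at step 76. With bidirectional=True it gives (-1, 3).

```python
def check_answer(arr, step_ids, bidirectional):
    """Replay step indices from node 0.

    Returns (temperature, number of steps) if the walk ends at node n-1,
    otherwise None. With bidirectional=True a step [u, v, c] may also be
    walked from v to u (keeping the same c); that is an assumption, not
    something the problem statement says.
    """
    n, m, steps = arr
    node, temp = 0, 0
    for idx in step_ids:
        u, v, c = steps[idx]
        if node == u:
            node = v
        elif bidirectional and node == v:
            node = u
        else:
            return None  # the step does not touch the node we are standing on
        temp += c
    return (temp, len(step_ids)) if node == n - 1 else None


example = [4, 3, [[0, 1, -1], [1, 3, -1], [1, 2, 1]]]
print(check_answer(example, [0, 1], bidirectional=False))         # (-2, 2)
reversed_steps = [3, 2, [[1, 0, 1], [2, 1, 1]]]                   # made-up: 0 -> 1 -> 2 uses both steps from v to u
print(check_answer(reversed_steps, [0, 1], bidirectional=False))  # None
print(check_answer(reversed_steps, [0, 1], bidirectional=True))   # (2, 2)
```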

EDIT 2

Wrong input/output cases for @Timeless's approach:

  1. The path is longer than the shortest one:
[29, 52, [[22, 11, 1], [28, 11, -1], [27, 6, -1], [1, 27, 1], [18, 1, -1], [25, 3, 1], [15, 20, 1], [15, 7, 1], [10, 27, -1], [28, 13, 1], [5, 4, -1], [0, 24, -1], [10, 7, -1], [6, 10, 1], [21, 0, 1], [13, 2, -1], [19, 13, -1], [20, 5, 1], [1, 11, 1], [3, 18, 1], [18, 8, 1], [0, 4, 1], [22, 6, -1], [16, 19, -1], [7, 1, -1], [20, 5, -1], [19, 22, 1], [1, 28, 1], [4, 17, -1], [17, 26, 1], [14, 11, -1], [14, 19, 1], [5, 1, -1], [11, 20, 1], [12, 25, -1], [3, 18, -1], [11, 15, -1], [7, 17, -1], [26, 4, -1], [6, 12, -1], [8, 21, 1], [28, 18, -1], [14, 16, 1], [15, 21, 1], [15, 23, -1], [6, 17, 1], [7, 8, -1], [17, 22, 1], [9, 8, 1], [7, 6, -1], [3, 26, 1], [18, 17, -1]]]

returns

(0, 8, [21, 28, 47, 22, 13, 12, 24, 27])

when a shorter path exists:

(0, 4, [21, 10, 32, 37])
  2. It returns ajajaj when a solution exists:
[31, 72, [[14, 21, 1], [0, 13, 1], [27, 20, -1], [20, 10, -1], [1, 22, 1], [11, 10, -1], [1, 23, 1], [25, 19, -1], [18, 19, 1], [10, 15, 1], [20, 4, 1], [15, 20, -1], [19, 16, 1], [27, 24, -1], [30, 19, -1], [28, 29, -1], [26, 0, 1], [9, 24, 1], [15, 16, -1], [19, 30, 1], [12, 0, 1], [18, 3, 1], [15, 22, 1], [26, 11, 1], [5, 14, 1], [18, 12, -1], [5, 12, 1], [15, 23, -1], [10, 23, 1], [5, 19, 1], [19, 18, 1], [28, 6, -1], [3, 23, -1], [12, 24, -1], [9, 0, -1], [2, 7, -1], [22, 0, 1], [15, 19, 1], [23, 2, 1], [9, 29, -1], [28, 27, 1], [3, 26, 1], [27, 29, -1], [17, 27, 1], [29, 3, 1], [16, 12, -1], [22, 27, -1], [12, 1, -1], [8, 17, 1], [29, 14, 1], [26, 10, 1], [16, 1, -1], [19, 26, 1], [29, 6, -1], [20, 27, 1], [23, 0, 1], [14, 29, -1], [23, 12, 1], [6, 2, -1], [26, 18, -1], [30, 5, 1], [30, 22, 1], [17, 30, -1], [14, 10, -1], [17, 10, -1], [6, 8, -1], [6, 21, 1], [3, 10, -1], [25, 27, -1], [4, 26, 1], [12, 17, 1], [12, 20, 1]]]

should return:

(2, 2, [36, 61])
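The EDIT 2 cases point at a second issue. Step 36 is [22, 0, 1] and step 61 is [30, 22, 1], so the expected path 0 -> 22 -> 30 walks both steps from v to u. The first EDIT's [35, 76, 54] has the same property. The expected outputs therefore seem to assume that steps work in both directions, even though the statement calls u the start. The sketch below makes this an explicit assumption through `bidirectional=True`.

It uses a technique not found in the question or in the networkx answer: a layered BFS over (node, temperature) states. Each node is kept only at its BFS distance from node 0, because any prefix of a shortest path is itself a shortest path. At that distance the node records every temperature it can be reached with, plus a parent pointer. After k steps the temperature lies in [-k, k], so the work is roughly O(m * D) for a shortest length D. The name `solve` is hypothetical.

```python
from collections import defaultdict


def solve(arr, bidirectional=True):
    """Fewest steps from node 0 to node n-1; among those, temperature closest to 0.

    ASSUMPTION: with bidirectional=True a step [u, v, c] may also be walked
    from v to u (same c), which the expected outputs in the EDITs require.
    """
    n, m, steps = arr
    target = n - 1
    graph = [[] for _ in range(n)]
    for idx, (u, v, c) in enumerate(steps):
        graph[u].append((v, c, idx))
        if bidirectional:
            graph[v].append((u, c, idx))

    depth = {0: 0}                 # BFS distance of every node reached so far
    parent = {(0, 0): None}        # (node, temp) -> (prev_node, prev_temp, step index)
    frontier = {0: {0}}            # node -> temperatures reachable in exactly k steps
    k = 0
    while frontier and target not in frontier:
        k += 1
        nxt = defaultdict(set)
        for node, temps in frontier.items():
            for v, c, idx in graph[node]:
                # only nodes whose BFS distance is k can sit at step k of a shortest path
                if depth.setdefault(v, k) != k:
                    continue
                for t in temps:
                    if (v, t + c) not in parent:  # keep the first way each state is found
                        parent[(v, t + c)] = (node, t, idx)
                        nxt[v].add(t + c)
        frontier = nxt

    if target not in frontier:
        return "ajajaj"

    # closest to 0; on a -x / +x tie this arbitrarily prefers the negative value
    best = min(frontier[target], key=lambda t: (abs(t), t))
    path, state = [], (target, best)
    while parent[state] is not None:
        prev_node, prev_temp, idx = parent[state]
        path.append(idx)
        state = (prev_node, prev_temp)
    return best, k, path[::-1]


examples = [
    [4, 3, [[0, 1, -1], [1, 3, -1], [1, 2, 1]]],
    [4, 4, [[0, 1, 1], [0, 2, -1], [1, 3, -1], [2, 3, 1]]],
    [3, 1, [[0, 1, 1]]],
    [5, 5, [[0, 1, 1], [1, 2, -1], [1, 2, 1], [2, 3, -1], [3, 4, -1]]],
]
for arr in examples:
    print(solve(arr))
# (-2, 2, [0, 1])
# (0, 2, [0, 2])
# ajajaj
# (0, 4, [0, 2, 3, 4])
print(solve([3, 2, [[1, 0, 1], [2, 1, 1]]]))                       # (2, 2, [0, 1])
print(solve([3, 2, [[1, 0, 1], [2, 1, 1]]], bidirectional=False))  # ajajaj
```

For the second example, the rules allow either [0, 2] or [1, 3]. This sketch keeps whichever one it finds first.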

Solution

  • If networkx is an option:

    import networkx as nx
    
    def sp_on_2c(arr, source=0):
        n, m, steps = arr
        
        DG = nx.DiGraph()
        for idx, (u, v, c) in enumerate(steps):
            DG.add_edge(u, v, weight=c, indices=idx)
    
        try:
            # `nx.shortest_path` would also work, but test2 has two shortest paths; taking the last one matches its expected output
            sp = list(nx.all_shortest_paths(DG, source=source, target=n-1))[-1]
            pg = nx.path_graph(sp).edges()
    
            return (
                sum(DG[u][v]["weight"] for u,v in pg),
                len(sp) - 1,  # a path through k nodes uses k - 1 steps
                [DG[u][v]["indices"] for u,v in pg],
            )
        
        except (nx.NodeNotFound, nx.NetworkXNoPath):
            return "ajajaj"
    

    Output:

    for idx, test in enumerate([test1, test2, test3, test4]):
        print(f"test{idx+1} >>", sp_on_2c(test))
        
    # test1 >> (-2, 2, [0, 1])
    # test2 >> (0, 2, [1, 3])
    # test3 >> ajajaj
    # test4 >> (0, 4, [0, 2, 3, 4])
    

    The graphs: [image of the four test graphs, not available]

    Used input:

    test1 = [4, 3, [[0, 1, -1], [1, 3, -1], [1, 2, 1]]]
    test2 = [4, 4, [[0, 1, 1], [0, 2, -1], [1, 3, -1], [2, 3, 1]]]
    test3 = [3, 1, [[0, 1, 1]]]
    test4 = [5, 5, [[0, 1, 1], [1, 2, -1], [1, 2, 1], [2, 3, -1], [3, 4, -1]]]