
Python - Prim's Algorithm Implementation with Array


I'm trying to implement Prim's algorithm in Python 3 so that it returns the total weight of the minimum spanning tree (MST) it generates. I'm doing something slightly unusual: I'm using an "array" (a list of three rows) to keep track of the unvisited nodes.

Here's my code:

def Prim(Graph):
    # row 1 is "still in R"
    # row 2 is the connector vertex
    # row 3 is the cost
    total = 0
    A = []
    n = len(Graph)
    A = [[None for x in range(0, n)] for y in range(1, 4)]
    #Debugging purposes
    #print(A)
    for x in range(1, n):
        A[0][x] = 'Y'
        A[1][x] = 0
        A[2][x] = 0

    for neighbour in Graph[1]: 
        A[1][neighbour-1] = 1
        A[2][neighbour-1] = Graph[1][neighbour]
        #Debugging purposes
        #print("Neighbour: ", neighbour, "Weight: ", Graph[1][neighbour])
    current = 1
    T = [current]
    MST_edges = {}
    count = 0
    while len(T) < n:
        x = search_min(current, A)
        T.append(x)
        MST_edges[x] = A[1][x]
        A[0][x] = 'N'
        total += A[2][x]

        #print(Graph)
        #print(A)
        for neighbour in Graph[x]:
            #print(neighbour)
            #print(A[2][neighbour-1])
            if A[0][neighbour-1] != 'N':
                if Graph[x][neighbour] < A[2][neighbour-1]:
                    A[1][neighbour-1] = x
                    A[2][neighbour-1] = Graph[x][neighbour]
        count += 1
        current = T[count]
    return total



def search_min(current, A):
    minimum_cost = 100
    minimum_vertex = 1
    for x in range(1,len(A[0])):
        if A[1][x] is not None and A[0][x] != 'N' and A[2][x] < minimum_cost:
            minimum_cost = A[2][x]
            minimum_vertex = x
            # Debugging
            # print("x", x)
            # print("cost", minimum_cost)
            # print("vertex", x)
    return minimum_vertex
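Separately from the initial values, two details in search_min can cause trouble on their own: starting from minimum_cost = 100 silently ignores any edge of weight 100 or more, and the fallback minimum_vertex = 1 returns vertex 1 even when no candidate exists. A minimal sketch of the same scan, assuming the question's three-row layout (row 0 'Y'/'N', row 1 connector, row 2 cost) and math.inf as "not reached yet", might look like this. The name search_min_inf is hypothetical, and it drops the unused current parameter.

```python
import math

def search_min_inf(A):
    """Return the index of the cheapest vertex still marked 'Y', or None.

    Assumes the three-row layout from the question:
    A[0][v] is 'Y' (still unvisited) or 'N' (already in the tree),
    A[2][v] is the cheapest known edge cost into v, math.inf if none.
    """
    best_cost = math.inf   # no hard-coded cap such as 100
    best_vertex = None     # None signals "nothing reachable is left"
    for v in range(len(A[0])):
        if A[0][v] == 'Y' and A[2][v] < best_cost:
            best_cost = A[2][v]
            best_vertex = v
    return best_vertex

# Hypothetical table for indices 0..3; index 0 is unused, as in the question.
A = [
    [None, 'N', 'Y', 'Y'],
    [None, None, 1, 1],
    [None, 0, 150, math.inf],
]
print(search_min_inf(A))  # 2
```

With the original cap of 100, the edge of weight 150 would never be chosen and the function would fall back to returning 1.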

It sometimes returns implausibly low totals such as 20, which should be next to impossible since the minimum edge weight is 10. I suspect the problem is in the while loop:

    while len(T) < n:
        x = search_min(current, A)
        T.append(x)
        MST_edges[x] = A[1][x]
        A[0][x] = 'N'
        total += A[2][x]

        #print(Graph)
        #print(A)
        for neighbour in Graph[x]:
            #print(neighbour)
            #print(A[2][neighbour-1])
            if A[0][neighbour-1] != 'N':
                if A[2][neighbour-1] != None and Graph[x][neighbour] < A[2][neighbour-1]:
                    A[1][neighbour-1] = x
                    A[2][neighbour-1] = Graph[x][neighbour]
        count += 1
        current = T[count]
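It may help to isolate the relaxation step, if Graph[x][neighbour] < A[2][neighbour-1]. Because the cost row starts at 0, no positive edge weight is ever smaller than the stored cost, so vertices whose cost was never overwritten keep cost 0 and connector 0, and search_min then picks them first and adds 0 to the total. The hypothetical relax function below mirrors that comparison for a single vertex and contrasts a 0 placeholder with a math.inf one.

```python
import math

def relax(stored_cost, stored_connector, new_weight, new_connector):
    """One relaxation step: keep the cheaper of the stored and new edge.

    Mirrors `if Graph[x][neighbour] < A[2][neighbour-1]` from the loop.
    Returns the (cost, connector) pair that would be stored afterwards.
    """
    if new_weight < stored_cost:
        return new_weight, new_connector
    return stored_cost, stored_connector

# Placeholder 0: the real edge of weight 10 from vertex 4 is never recorded.
print(relax(0, 0, 10, 4))         # (0, 0)
# Placeholder math.inf: the real edge replaces the placeholder.
print(relax(math.inf, 0, 10, 4))  # (10, 4)
```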

But I can't tell which part is wrong. It's getting late and my head hurts; any help would be great.
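The "next to impossible" argument can be turned into an automatic check: a spanning tree on n vertices has exactly n - 1 edges, so if no edge weighs less than w_min, no spanning tree can weigh less than w_min * (n - 1). A small sketch, assuming the graph is a dict of dicts {vertex: {neighbour: weight}} (the shape the Graph[x][neighbour] lookups suggest) and is connected; mst_weight_lower_bound is a hypothetical name.

```python
def mst_weight_lower_bound(graph):
    """Smallest total any spanning tree of `graph` could have.

    Assumes graph is {vertex: {neighbour: weight}} and is connected.
    A spanning tree on n vertices has n - 1 edges, each weighing at
    least as much as the lightest edge in the graph.
    """
    n = len(graph)
    lightest = min(w for nbrs in graph.values() for w in nbrs.values())
    return lightest * (n - 1)

# Hypothetical 4-vertex graph whose lightest edge weighs 10.
g = {
    1: {2: 10, 3: 15},
    2: {1: 10, 3: 12, 4: 20},
    3: {1: 15, 2: 12, 4: 11},
    4: {2: 20, 3: 11},
}
print(mst_weight_lower_bound(g))  # 30
```

Asserting that the computed total is at least this bound after each run catches results like 20 immediately; if construct_graph(20) builds 20 vertices with minimum weight 10, the bound would be 190.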

EDIT: Here's an example of the MST it generates. For some reason, some vertices have 0-weighted edges.

graph = construct_graph(20)
Prim(graph)
{3: 0, 5: 0, 8: 0, 16: 0, 6: 5, 9: 3, 7: 8, 11: 5, 15: 11, 12: 11, 2: 8, 18: 2, 19: 2, 1: 19, 10: 19, 14: 10, 17: 5, 13: 16, 4: 1}

(From the code, in each entry x: y, x is the vertex and y is the weight of its connecting edge. For some reason, some vertices have weight 0.)
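One thing worth checking in this output: MST_edges[x] = A[1][x] reads the second row of the table, which the code's own comments describe as the connector vertex, while the cost lives in the third row (A[2]). So a 0 in the dictionary most likely means the connector was never updated from its initial 0, rather than an edge of weight 0. A sketch that records both values, assuming the same three-row table, could build {vertex: (connector, cost)} entries instead; tree_edges is a hypothetical helper.

```python
import math

def tree_edges(A, order):
    """Map each vertex added to the tree to its (connector, cost) pair.

    Assumes the three-row table from the question: A[1][v] is the
    connector vertex and A[2][v] the cost of the edge that brought v in.
    `order` lists vertices in the order they were added (like T, minus
    the start vertex, which has no connecting edge).
    """
    return {v: (A[1][v], A[2][v]) for v in order}

# Hypothetical final table for vertices 1..3 (index 0 unused); vertex 1 is the start.
A = [
    [None, 'N', 'N', 'N'],
    [None, None, 1, 2],
    [None, math.inf, 10, 12],
]
edges = tree_edges(A, [2, 3])
print(edges)                                     # {2: (1, 10), 3: (2, 12)}
print(sum(cost for _, cost in edges.values()))   # 22
```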


Solution

  • On someone's advice, I changed this line of code:

    A[2][x] = 0
    

    To this:

    A[2][x] = math.inf
    

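    This way the array no longer mistakes the placeholder for a real edge ("woot, an edge with weight 0!") when it should mean "not connected yet". So it all comes down to choosing the right value for the illegal case: the placeholder must compare greater than every real edge weight, which math.inf does (it needs import math at the top of the file).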
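For comparison, here is a minimal sketch of the whole array-based approach with math.inf as the "not connected yet" value. It assumes graph is a dict of dicts {vertex: {neighbour: weight}} with vertices labelled 1..n and each undirected edge listed in both directions. prim_total is a hypothetical name, and unlike the question's Prim it returns the tree edges along with the total. It also indexes the lists by vertex label directly, because the original mixes neighbour-1 (0-based) with x (used as a raw label) when reading and writing A, which looks like a second possible source of error.

```python
import math

def prim_total(graph):
    """Array-based Prim: return (total MST weight, {vertex: (connector, weight)}).

    Assumes graph is {vertex: {neighbour: weight}} with vertices labelled
    1..n and undirected edges listed in both directions. Vertex labels are
    used directly as list indices; index 0 is unused.
    """
    n = len(graph)
    in_tree = [False] * (n + 1)
    connector = [None] * (n + 1)
    cost = [math.inf] * (n + 1)   # math.inf = "no edge seen yet"
    cost[1] = 0                   # start from vertex 1
    total = 0
    for _ in range(n):
        # Pick the cheapest vertex not yet in the tree (O(n) scan).
        u = None
        for v in range(1, n + 1):
            if not in_tree[v] and (u is None or cost[v] < cost[u]):
                u = v
        if cost[u] == math.inf:
            raise ValueError("graph is not connected")
        in_tree[u] = True
        total += cost[u]
        # Relax edges out of u: keep the cheapest edge into each outside vertex.
        for v, w in graph[u].items():
            if not in_tree[v] and w < cost[v]:
                cost[v] = w
                connector[v] = u
    edges = {v: (connector[v], cost[v])
             for v in range(1, n + 1) if connector[v] is not None}
    return total, edges

g = {
    1: {2: 10, 3: 15},
    2: {1: 10, 3: 12, 4: 20},
    3: {1: 15, 2: 12, 4: 11},
    4: {2: 20, 3: 11},
}
print(prim_total(g))  # (33, {2: (1, 10), 3: (2, 12), 4: (3, 11)})
```

The selection scan costs O(n) per step, so O(n^2) overall, which suits dense graphs; for sparse graphs, a priority queue built on the standard heapq module is the usual alternative.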