Search code examples
pythonalgorithmnetworkxa-starheuristics

Assigning x,y coords in networkx/python for a* search heuristic


I am trying to implement an a* search algorithm in python on a 6*6 interconnected node grid, using networkx to organize nodes and matplotlib to display. I've got it working so it finds the shortest path, but without the heuristic, it's just brute force search- which is too costly. How can I assign x,y coordinates to my nodes when I create them or is there any other way to get the heuristic to work? (I know networkx has an inbuilt a* function I could use, but I want to demonstrate that I can implement the algorithm)

Here's the code ( a bit messy from StackOverflow formatting):

import networkx as nx
G=nx.Graph()

import matplotlib.pyplot as plt


def add_nodes():
    G.add_nodes_from([0, 1, 2, 3, 4, 5, \
         6, 7, 8, 9, 10, 11, \
         12, 13, 14, 15, 16, 17, \
         18, 19, 20, 21, 22, 23, \
         24, 25, 26, 27, 28, 29,\
         30, 31, 32, 33, 34, 35])


    c = 0
    for y in range (0,5):
        for x in range (0,5):
            G.add_node(c, pos=(x/10,y/10))
            c=c+1
#http://stackoverflow.com/questions/477486/how-to-use-a-decimal-range-step-value
#prev code for brute force search:
#https://pastebin.com/DT76bvw5
#node pos: http://stackoverflow.com/questions/11804730/networkx-add-node-with-specific-position
#http://theory.stanford.edu/~amitp/GameProgramming/Heuristics.html
#http://zurb.com/forrst/posts/A_algorithm_in_python-B4c



for a in G.nodes():
    if a not in (5, 11, 17, 23,29, 35):
        G.add_edge(a, a+1)
    if a not in (30, 31, 32, 33, 34, 35):
        G.add_edge(a, a+6)
    if a not in (5, 11, 17, 23, 29, 30, 31, 32, 33, 34, 35):
        G.add_edge(a, a+7)
    if a not in (0, 6, 12, 18, 24, 30, 31, 32, 33, 34, 35):
        G.add_edge(a, a+5)

def heuristic(a, b):
    (x1, y1) = a
    (x2, y2) = b
    return abs(x1-x2) + abs(y1-y2)

#def cost (from_node, to_node):


def a_star_search(graph, start, end):    
    #initialise open list
    open_nodes = []
    #initialise closed list
    closed_nodes = {}
    #put starting node on open list
    open_nodes.append(start)
    #initialise cost list
    cost_so_far = {}
    #no previous path
    closed_nodes[start] = None
    cost_so_far[start] = 0


    #lists for colour:
    eVisited= []
    ePath = []


    #while open list is not empty
    while (len(open_nodes) != 0):
        #pop q off the open list
        current = open_nodes.pop()


        #for each neighbour

        for next in G.neighbors(current):
            new_cost = cost_so_far[current] + G[current][next]['weight']
            print('cost between '+ str(current) + ' and ' + str(next) + ' = ' + str(new_cost))
            if next not in cost_so_far or new_cost< cost_so_far[next]:
                cost_so_far[next] = new_cost
                print('minimal cost for start to ' + str(next) +  ' found')

                #assign colour to show it's been added
                eVisited.append((current, next))

                #priority = new_cost + heuristic(end, next)
                open_nodes.append(next) #(next, priority)
                closed_nodes[next] = current
                print('node connected: ' + str(next))

    print(closed_nodes)
    v = closed_nodes[end]
    ePath.append((end, closed_nodes[end]))
    while v != start:
        ePath.append((v, closed_nodes[v]))
        v = closed_nodes[v]
    print(ePath)
    return ePath, eVisited

add_nodes()
ePath, eVisited = a_star_search(G, 18, 3)
pos=nx.spectral_layout(G) #positions for all nodes(?)

#draw nodes
nx.draw_networkx_nodes(G, pos, node_size=300)


#draw edges
nx.draw_networkx_edges(G, pos, width=3)
nx.draw_networkx_edges(G, pos, edgelist=eVisited, width = 6, edge_color='g')
nx.draw_networkx_edges(G, pos, edgelist=ePath, width = 6, edge_color='b')

#labels
nx.draw_networkx_labels(G, pos, font_size=10, font_family='sans-serif')

plt.grid('on')

#disable axis
plt.axis('off')
#draw graph
plt.show()

Solution

  • You can assign a position to each node:

    for n in G:
        x, y = n // 6, n % 6  # row and column coordinates
        G.node[n]['pos'] = (x, y)
    

    With that, you can access the property.

    for n, data in G.nodes(data=True):
        print(n, data['pos'])
    # (0, 0)
    # (0, 1)
    # (0, 2)
    # ...
    

    Edit: For people who find this useful in the future, as noted, this can be plotted with the following:

    nx.draw(G, pos=nx.get_node_attributes(G, 'pos'))