I am trying to implement an a* search algorithm in python on a 6*6 interconnected node grid, using networkx to organize nodes and matplotlib to display. I've got it working so it finds the shortest path, but without the heuristic, it's just brute force search- which is too costly. How can I assign x,y coordinates to my nodes when I create them or is there any other way to get the heuristic to work? (I know networkx has an inbuilt a* function I could use, but I want to demonstrate that I can implement the algorithm)
Here's the code ( a bit messy from StackOverflow formatting):
import networkx as nx
G=nx.Graph()
import matplotlib.pyplot as plt
def add_nodes():
G.add_nodes_from([0, 1, 2, 3, 4, 5, \
6, 7, 8, 9, 10, 11, \
12, 13, 14, 15, 16, 17, \
18, 19, 20, 21, 22, 23, \
24, 25, 26, 27, 28, 29,\
30, 31, 32, 33, 34, 35])
c = 0
for y in range (0,5):
for x in range (0,5):
G.add_node(c, pos=(x/10,y/10))
c=c+1
#http://stackoverflow.com/questions/477486/how-to-use-a-decimal-range-step-value
#prev code for brute force search:
#https://pastebin.com/DT76bvw5
#node pos: http://stackoverflow.com/questions/11804730/networkx-add-node-with-specific-position
#http://theory.stanford.edu/~amitp/GameProgramming/Heuristics.html
#http://zurb.com/forrst/posts/A_algorithm_in_python-B4c
for a in G.nodes():
if a not in (5, 11, 17, 23,29, 35):
G.add_edge(a, a+1)
if a not in (30, 31, 32, 33, 34, 35):
G.add_edge(a, a+6)
if a not in (5, 11, 17, 23, 29, 30, 31, 32, 33, 34, 35):
G.add_edge(a, a+7)
if a not in (0, 6, 12, 18, 24, 30, 31, 32, 33, 34, 35):
G.add_edge(a, a+5)
def heuristic(a, b):
(x1, y1) = a
(x2, y2) = b
return abs(x1-x2) + abs(y1-y2)
#def cost (from_node, to_node):
def a_star_search(graph, start, end):
#initialise open list
open_nodes = []
#initialise closed list
closed_nodes = {}
#put starting node on open list
open_nodes.append(start)
#initialise cost list
cost_so_far = {}
#no previous path
closed_nodes[start] = None
cost_so_far[start] = 0
#lists for colour:
eVisited= []
ePath = []
#while open list is not empty
while (len(open_nodes) != 0):
#pop q off the open list
current = open_nodes.pop()
#for each neighbour
for next in G.neighbors(current):
new_cost = cost_so_far[current] + G[current][next]['weight']
print('cost between '+ str(current) + ' and ' + str(next) + ' = ' + str(new_cost))
if next not in cost_so_far or new_cost< cost_so_far[next]:
cost_so_far[next] = new_cost
print('minimal cost for start to ' + str(next) + ' found')
#assign colour to show it's been added
eVisited.append((current, next))
#priority = new_cost + heuristic(end, next)
open_nodes.append(next) #(next, priority)
closed_nodes[next] = current
print('node connected: ' + str(next))
print(closed_nodes)
v = closed_nodes[end]
ePath.append((end, closed_nodes[end]))
while v != start:
ePath.append((v, closed_nodes[v]))
v = closed_nodes[v]
print(ePath)
return ePath, eVisited
add_nodes()
ePath, eVisited = a_star_search(G, 18, 3)
pos=nx.spectral_layout(G) #positions for all nodes(?)
#draw nodes
nx.draw_networkx_nodes(G, pos, node_size=300)
#draw edges
nx.draw_networkx_edges(G, pos, width=3)
nx.draw_networkx_edges(G, pos, edgelist=eVisited, width = 6, edge_color='g')
nx.draw_networkx_edges(G, pos, edgelist=ePath, width = 6, edge_color='b')
#labels
nx.draw_networkx_labels(G, pos, font_size=10, font_family='sans-serif')
plt.grid('on')
#disable axis
plt.axis('off')
#draw graph
plt.show()
You can assign a position to each node:
for n in G:
x, y = n // 6, n % 6 # row and column coordinates
G.node[n]['pos'] = (x, y)
With that, you can access the property.
for n, data in G.nodes(data=True):
print(n, data['pos'])
# (0, 0)
# (0, 1)
# (0, 2)
# ...
Edit: For people who find this useful in the future, as noted, this can be plotted with the following:
nx.draw(G, pos=nx.get_node_attributes(G, 'pos'))