Search code examples
c#.netalgorithma-star

Why isn't my A* implementation finding the shortest path but rather the first path?


Here is a visual representation of what's happening

enter image description here

For some reason my implementation doesn't seem to calculate the shortest path, but rather the first path it finds, and I can't seem to figure out why.

Is there something that I've missed that's super obvious?

Here's the notes I took while learning and that's what I've been going over

  1. Initialize the open set with the starting node and an empty closed set.

  2. While the open set is not empty, select the node with the lowest cost to reach it from the starting node, plus the estimated cost to reach the goal node from that node, and remove it from the open set.

  3. If the selected node is the goal node, the algorithm terminates and the path is traced back from the goal node to the starting node.

  4. Otherwise, add the selected node to the closed set and evaluate its neighboring nodes. For each neighboring node that is not in the closed set and is not an obstacle, calculate its tentative cost to reach it from the starting node and the estimated cost to reach the goal node from that node, using the heuristic function. Add the neighboring nodes to the open set if they are not already in it.

  5. Repeat steps 2-4 until the goal node is found or the open set is empty.

The A* algorithm should guarantee that it will find the shortest path from the starting node to the destination node, as long as the heuristic function is admissible (i.e., it never overestimates the actual cost to reach the goal node) and the grid does not contain cycles or negative edge weights.

public class World
{
    public int Width { get; set; }
    public int Height { get; set; }

    public Node[,] Nodes { get; set; }
    public List<Node> Path { get; set; }

    public World(int width, int height)
    {
        Width = width;
        Height = height;
        Nodes = new Node[Width, Height];

        Build();
    }

    private void Build()
    {
        for (int y = 0; y < Height; y++)
        {
            for (int x = 0; x < Width; x++)
            {
                Nodes[x, y] = new Node(x, y);
            }
        }
    }

    public void Find()
    {
        Node Start = GetStartNode(Nodes);
        Node Destination = GetDestinationNode(Nodes);

        var openSet = new List<Node>();
        var closedSet = new List<Node>();
        openSet.Add(Start);

        while (openSet.Any())
        {
            Node n = GetLowestFCostNode(openSet);
            
            if (n.NodeState == NodeState.Destination)
            {
                /* Trace Back Path */
                Path = TracebackPath(n, Start);
                break;
            }

            closedSet.Add(n);
            openSet.Remove(n);

            foreach (var node in GetAdjacentNodes(n, Nodes))
            {
                if (closedSet.Contains(node) || node.NodeState == NodeState.Obstruction)
                    continue;

                int currentGCost = node.GCost + NodeDistance(n, node);

                bool recalculate;
                if (!openSet.Contains(node))
                {
                    openSet.Add(node);
                    recalculate = true;
                }
                else if (currentGCost < node.GCost)
                {
                    recalculate = true;
                }
                else
                {
                    recalculate = false;
                }

                if (recalculate)
                {
                    node.Parent = n;
                    node.GCost = currentGCost;
                    node.HCost = HueristicsCost(node, Destination);
                    node.CalculateFCost();
                }
            }
        }
    }

    private int HueristicsCost(Node node, Node destination)
    {
        int dx = Math.Abs(node.X - destination.X);
        int dy = Math.Abs(node.Y - destination.Y);
        return 10 * (dx + dy);
    }

    private int NodeDistance(Node a, Node b)
    {
        if (Math.Abs(a.X - b.X) == 1 && Math.Abs(a.Y - b.Y) == 1)
            return 14;
        return 10;
    }

    private List<Node> GetAdjacentNodes(Node candidate, Node[,] fields)
    {
        var fieldList = new List<Node>();
        var width = fields.GetLength(0);
        var height = fields.GetLength(1);

        /* Check Lateral Neighbors */
        for (var x = candidate.X - 1; x <= candidate.X + 1; x++)
        {
            /* Check Vertical Neighbors */
            for (var y = candidate.Y - 1; y <= candidate.Y + 1; y++)
            {
                /* Bounds Check */
                if (x >= 0 && x < width && y >= 0 && y < height && (x != candidate.X || y != candidate.Y))
                {
                    fieldList.Add(fields[x, y]);
                }
            }
        }

        return fieldList;
    }


    private Node GetLowestFCostNode(List<Node> openSet)
    {
        Node lowestFCostNode = null;
        int lowestFCost = int.MaxValue;

        foreach (Node node in openSet)
        {
            if (node.FCost < lowestFCost || (node.FCost == lowestFCost && node.HCost < lowestFCostNode.HCost))
            {
                lowestFCost = node.FCost;
                lowestFCostNode = node;
            }
        }

        return lowestFCostNode;
    }

    private List<Node> TracebackPath(Node node, Node start)
    {
        List<Node> list = new List<Node>();
        while (node != start)
        {
            list.Add(node);
            node.Tile.Fill = Brushes.Cyan;
            node = node.Parent;
        }

        list.Add(start);

        return list;
    }

    private Node GetDestinationNode(Node[,] nodes)
    {
        foreach (var node in nodes)
        {
            if (node.NodeState == NodeState.Destination)
                return node;
        }

        throw new Exception("No Destination Field.");
    }

    private Node GetStartNode(Node[,] nodes)
    {
        foreach (var node in nodes)
        {
            if (node.NodeState == NodeState.Start)
                return node;
        }

        throw new Exception("Could not find a starting node.");
    }
}

Node.cs

public class Node
{
    public int X { get; }
    public int Y { get; }
    public Rectangle Tile { get; set; }

    /* Estimated Distance from CurrentNode to the StartNode */
    public int GCost { get; set; }

    /* Estimated Distance From CurrentNode To DestinationNode */
    public int HCost { get; set; }

    /* G + HCost Combined */
    public int FCost { get; set; }
    public Node Parent { get; set; }

    private NodeState _nodeState;

    public NodeState NodeState
    {
        get { return _nodeState; }
        set
        {
            InvokeNodeState(value);
            _nodeState = value;
        }
    }

    public Node(int x, int y)
    {
        X = x;
        Y = y;

        CreateTile();
    }

    private void CreateTile()
    {
        Tile = new Rectangle()
        {
            Width = 25,
            Height = 25,
            Fill = Brushes.ForestGreen,
            Stroke = Brushes.Black,
            StrokeThickness = 2
        };

        Tile.MouseDown += (sender, args) =>
        {
            switch (NodeState)
            {
                case NodeState.None:
                    NodeState = NodeState.Obstruction;
                    break;
                case NodeState.Obstruction:
                    NodeState = NodeState.Start;
                    break;
                case NodeState.Start:
                    NodeState = NodeState.Destination;
                    break;
                case NodeState.Destination:
                    NodeState = NodeState.None;
                    break;
                default:
                    throw new ArgumentOutOfRangeException();
            }
        };

        Canvas.SetLeft(Tile, X * 25);
        Canvas.SetTop(Tile, Y * 25);
    }

    private void InvokeNodeState(NodeState value)
    {
        switch (value)
        {
            case NodeState.None:
                Tile.Fill = Brushes.ForestGreen;
                break;
            case NodeState.Obstruction:
                Tile.Fill = Brushes.SaddleBrown;
                break;
            case NodeState.Start:
                Tile.Fill = Brushes.Yellow;
                break;
            case NodeState.Destination:
                Tile.Fill = Brushes.Red;
                break;
            default:
                throw new ArgumentOutOfRangeException(nameof(value), value, null);
        }
    }

    public void CalculateFCost()
    {
        FCost = GCost + HCost;
        Tile.Fill = Brushes.DarkGreen;
        Tile.ToolTip = $"FCost: {FCost} - GCost: {GCost} - HCost: {HCost}";
    }
}

public enum NodeState
{
    None,
    Obstruction,
    Start,
    Destination
}

Solution

  • The heuristic used to estimate the remaining distance should never overestimate the remaining distance. Your heuristic is:

    int dx = Math.Abs(node.X - destination.X);
    int dy = Math.Abs(node.Y - destination.Y);
    return 10 * (dx + dy);
    

    while your distance calculation is

    if (Math.Abs(a.X - b.X) == 1 && Math.Abs(a.Y - b.Y) == 1)
        return 14;
    return 10;
    

    So consider nodes (0, 0) and (1,1), the estimated distance is 20, while the actual distance is 14. This makes the heuristic inadmissible. The simplest fix for this is to change (dx + dy) to Math.Max(dx, dy).

    When debugging A* it may be useful to set the heuristic to 0, this should turn the algorithm into Djikstra. If that works better you know your problem is related to the heuristic.

    Some notes on data structures, if you want to improve the performance of your algorithm you should probably use something other than List for your open/closed set. For the closed set some common choices are a HashSet, or a 2D array. For the open set the typical data structure is a MinHeap, since that has both fast inserts and removals.