
Alpha Beta Pruning with Binary Search Tree


I am working through the Minimax with Alpha-Beta Pruning example found here. The example uses an array to implement the search tree. I followed it, and then also tried implementing the tree as a binary search tree (BST). Here are the values I'm using in the tree: 3, 5, 6, 9, 1, 2, 0, -1.

The optimal value at the end should be 5. With the BST implementation, I keep getting 2.
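For comparison, here is a minimal sketch of the array-indexed formulation usually used for this example. It is an assumption about what the linked example looks like, not a copy of it: the eight values are the leaves of a perfect binary tree of depth 3, and the children of node `i` at the next level are `i * 2` and `i * 2 + 1`. The class name `ArrayAlphaBeta` is hypothetical.

```java
public class ArrayAlphaBeta {
    static final int MAX = 1000;
    static final int MIN = -1000;

    // Leaves are values[0..7]; depth 3 is where the 2^3 = 8 leaves live.
    static int minimax(int depth, int nodeIndex, boolean maximizingPlayer,
                       int[] values, int alpha, int beta) {
        if (depth == 3) {
            return values[nodeIndex];
        }
        int best = maximizingPlayer ? MIN : MAX;
        for (int i = 0; i < 2; i++) {
            // Each iteration visits a different child: i = 0 is left, i = 1 is right.
            int val = minimax(depth + 1, nodeIndex * 2 + i, !maximizingPlayer,
                              values, alpha, beta);
            if (maximizingPlayer) {
                best = Math.max(best, val);
                alpha = Math.max(alpha, best);
            } else {
                best = Math.min(best, val);
                beta = Math.min(beta, best);
            }
            if (beta <= alpha) {
                break; // the remaining sibling cannot change the result
            }
        }
        return best;
    }

    public static void main(String[] args) {
        int[] values = {3, 5, 6, 9, 1, 2, 0, -1};
        // alpha starts at the low bound (MIN), beta at the high bound (MAX).
        System.out.println("Optimal Value: " + minimax(0, 0, true, values, MIN, MAX));
    }
}
```

This prints `Optimal Value: 5`. The point to notice is that the tree's shape comes entirely from the index arithmetic, not from the order of the values.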

I think this is the problem, but I don't know how to get around it:
I wrote the code to return out of the recursion when it reaches a leaf node, to avoid NullPointerExceptions when checking the next value. Instead, I think it stops the search too early (based on what I see when stepping through the code with the debugger). If I remove the check, though, the code fails with a NullPointerException.
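One way to see why checking only one side for null is not a leaf test is to list each node's depth and child count for the BST built from these values. This is a minimal sketch; the `Node` class and `BstShape` name are hypothetical stand-ins for the question's `MiniMaxNode` and `BinarySearchTree`, assuming a plain unbalanced BST where smaller values go left.

```java
public class BstShape {
    // Hypothetical node type standing in for MiniMaxNode.
    static final class Node {
        final int value;
        Node left, right;
        Node(int value) { this.value = value; }
    }

    // Plain unbalanced BST insert: smaller values go left, others go right.
    static Node insert(Node node, int value) {
        if (node == null) return new Node(value);
        if (value < node.value) node.left = insert(node.left, value);
        else node.right = insert(node.right, value);
        return node;
    }

    // Pre-order walk recording each node's depth and number of children.
    static void describe(Node node, int depth, StringBuilder out) {
        if (node == null) return;
        int children = (node.left != null ? 1 : 0) + (node.right != null ? 1 : 0);
        out.append(node.value).append(" depth=").append(depth)
           .append(" children=").append(children).append('\n');
        describe(node.left, depth + 1, out);
        describe(node.right, depth + 1, out);
    }

    static String describe(Node root) {
        StringBuilder out = new StringBuilder();
        describe(root, 0, out);
        return out.toString();
    }

    public static void main(String[] args) {
        Node root = null;
        for (int v : new int[] {3, 5, 6, 9, 1, 2, 0, -1}) {
            root = insert(root, v);
        }
        System.out.print(describe(root));
    }
}
```

It prints:

```
3 depth=0 children=2
1 depth=1 children=2
0 depth=2 children=1
-1 depth=3 children=0
2 depth=2 children=0
5 depth=1 children=1
6 depth=2 children=1
9 depth=3 children=0
```

Only 2, -1 and 9 are true leaves; 0, 5 and 6 have exactly one child, so a check such as `getLeft() == null` alone treats them as leaves. The levels are also not full (three nodes at depth 2, two at depth 3), unlike the array-based tree.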

Can someone point me in the right direction? What am I doing wrong?

Here's the code:

public class AlphaBetaMiniMax {

    private static BinarySearchTree myTree = new BinarySearchTree();
    static int MAX = 1000;
    static int MIN = -1000;
    static int opt;

    public static void main(String[] args) {
        //Start constructing the game
        AlphaBetaMiniMax demo = new AlphaBetaMiniMax();

        //3, 5, 6, 9, 1, 2, 0, -1
        demo.myTree.insert(3);
        demo.myTree.insert(5);
        demo.myTree.insert(6);
        demo.myTree.insert(9);
        demo.myTree.insert(1);
        demo.myTree.insert(2);
        demo.myTree.insert(0);
        demo.myTree.insert(-1);

        //print the tree
        System.out.println("Game Tree: ");
        demo.myTree.printTree(demo.myTree.root);

        //Print the results of the game
        System.out.println("\nGame Results:");

        //run the  minimax algorithm with the following inputs
        int optimalVal = demo.minimax(0, myTree.root, true, MAX, MIN);
        System.out.println("Optimal Value: " + optimalVal);

    }

    /**
     * @param alpha = 1000
     * @param beta = -1000
     * @param nodeIndex - the current node
     * @param depth - the depth to search
     * @param maximizingPlayer - the current player making a move
     * @return - the best move for the current player
     */
    public int minimax(int depth, MiniMaxNode nodeIndex, boolean maximizingPlayer, double alpha, double beta) {

        //Base Case #1: Reached the bottom of the tree
        if (depth == 2) {
            return nodeIndex.getValue();
        }

        //Base Case #2: if reached a leaf node, return the value of the current node
        if (nodeIndex.getLeft() == null && maximizingPlayer == false) {
            return nodeIndex.getValue();
        } else if (nodeIndex.getRight() == null && maximizingPlayer == true) {
            return nodeIndex.getValue();
        }

        //Mini-Max Algorithm
        if (maximizingPlayer) {
            int best = MIN;

            //Recur for left and right children
            for (int i = 0; i < 2; i++) {

                int val = minimax(depth + 1, nodeIndex.getLeft(), false, alpha, beta);
                best = Math.max(best, val);
                alpha = Math.max(alpha, best);

                //Alpha Beta Pruning
                if (beta <= alpha) {
                    break;
                }
            }
            return best;
        } else {
            int best = MAX;

            //Recur for left and right children
            for (int i = 0; i < 2; i++) {

                int val = minimax(depth + 1, nodeIndex.getRight(), true, alpha, beta);
                best = Math.min(best, val);
                beta = Math.min(beta, best);

                //Alpha Beta Pruning
                if (beta <= alpha) {
                    break;
                }
            }
            return best;
        }
    }
}

Output:

Game Tree: 
-1 ~ 0 ~ 1 ~ 2 ~ 3 ~ 5 ~ 6 ~ 9 ~ 
Game Results:
Optimal Value: 2
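The question does not show `BinarySearchTree` or `MiniMaxNode`. To reproduce it, here is a hypothetical minimal version that is consistent with how the code above uses them (`insert`, `root`, `printTree`, `getLeft`, `getRight`, `getValue`); the asker's real classes may differ.

```java
// Hypothetical node class: only the accessors that minimax() calls.
class MiniMaxNode {
    private final int value;
    MiniMaxNode left, right;

    MiniMaxNode(int value) { this.value = value; }

    int getValue() { return value; }
    MiniMaxNode getLeft() { return left; }
    MiniMaxNode getRight() { return right; }
}

// Hypothetical unbalanced BST: smaller values go left, others go right.
class BinarySearchTree {
    MiniMaxNode root;

    void insert(int value) {
        root = insert(root, value);
    }

    private MiniMaxNode insert(MiniMaxNode node, int value) {
        if (node == null) return new MiniMaxNode(value);
        if (value < node.getValue()) node.left = insert(node.left, value);
        else node.right = insert(node.right, value);
        return node;
    }

    // In-order walk, printing values separated by " ~ ".
    void printTree(MiniMaxNode node) {
        if (node == null) return;
        printTree(node.left);
        System.out.print(node.getValue() + " ~ ");
        printTree(node.right);
    }
}
```

With these two classes in the same package, the question's `AlphaBetaMiniMax` compiles and produces the output shown above.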

Solution

  • Your problem is that your iterations depend on a fixed loop count of 2 rather than on checking whether nodeIndex.getRight() (for max) or nodeIndex.getLeft() (for min) is null.

    Remember that a full binary tree has 1 node at the head (the first level), 2 at the second level, 4 at the third, 8 at the fourth, and so on. So, as written, your loop will not even go down 3 levels.

    for (int i = 0; i < 2; i++) {

        int val = minimax(depth + 1, nodeIndex.getLeft(), false, alpha, beta);
        best = Math.max(best, val);
        alpha = Math.max(alpha, best);

        //Alpha Beta Pruning
        if (beta <= alpha) {
            break;
        }
    }

    Change your loops so that the iteration is controlled by which children actually exist rather than by a fixed count, and you should find the highest value easily.
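A minimal sketch of that loop change, using a hypothetical `Node` class in place of `MiniMaxNode`. In the original loops every iteration recurses into the same child (`getLeft()` for max, `getRight()` for min), so the other child is never visited; here each existing child is visited once and null children are skipped. The sketch also makes two changes beyond the loop: it starts alpha at the low bound and beta at the high bound (the question's call passes `MAX` as alpha and `MIN` as beta, so `beta <= alpha` is already true after the first child and every loop breaks early), and it drops the `depth == 2` cutoff, treating only childless nodes as leaves.

```java
public class BstAlphaBeta {
    static final int MAX = 1000;
    static final int MIN = -1000;

    // Hypothetical node type standing in for MiniMaxNode.
    static final class Node {
        final int value;
        Node left, right;
        Node(int value) { this.value = value; }
    }

    static Node insert(Node node, int value) {
        if (node == null) return new Node(value);
        if (value < node.value) node.left = insert(node.left, value);
        else node.right = insert(node.right, value);
        return node;
    }

    static int minimax(Node node, boolean maximizing, int alpha, int beta) {
        // A true leaf has no children at all: its value is the score.
        if (node.left == null && node.right == null) {
            return node.value;
        }
        int best = maximizing ? MIN : MAX;
        // Visit each existing child once (left, then right); skip null children.
        for (Node child : new Node[] {node.left, node.right}) {
            if (child == null) continue;
            int val = minimax(child, !maximizing, alpha, beta);
            if (maximizing) {
                best = Math.max(best, val);
                alpha = Math.max(alpha, best);
            } else {
                best = Math.min(best, val);
                beta = Math.min(beta, best);
            }
            if (beta <= alpha) {
                break; // prune the remaining sibling
            }
        }
        return best;
    }

    public static void main(String[] args) {
        Node root = null;
        for (int v : new int[] {3, 5, 6, 9, 1, 2, 0, -1}) {
            root = insert(root, v);
        }
        // alpha starts at MIN, beta at MAX.
        System.out.println("Optimal Value: " + minimax(root, true, MIN, MAX));
    }
}
```

This prints `Optimal Value: 9`, not 5. A BST built by inserting 3, 5, 6, 9, 1, 2, 0, -1 in that order has a different shape from the array-based game tree: its true leaves are only -1, 2 and 9, whereas the 5 comes from treating all eight values as the leaves of a perfect binary tree. To reproduce 5, the tree has to be built with that shape rather than by BST insertion.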