python performance minimax

Performance Optimization for Minimax Algorithm in Tic-Tac-Toe with Variable Board Sizes and Win Conditions


I'm implementing a Tic-Tac-Toe game with an AI that uses the Minimax algorithm (with alpha-beta pruning) to select optimal moves. However, I'm running into performance issues as the board grows or as the number of consecutive marks required to win (k) increases. The current implementation works fine on small boards (e.g., 3x3), but on larger grids (e.g., 8x12) performance drops significantly, especially as the search depth increases.

The algorithm is running too slowly on larger boards (e.g., 6x6 or 8x12) with more complex win conditions (e.g., 4 in a row, 5 in a row).

import random
import time

# Check if a player has won
def check_win(board, player, k):
    rows = len(board)
    cols = len(board[0])

    # Check horizontally, vertically, and diagonally
    for r in range(rows):
        for c in range(cols):
            if c <= cols - k and all(board[r][c+i] == player for i in range(k)):  # Horizontal
                return True
            if r <= rows - k and all(board[r+i][c] == player for i in range(k)):  # Vertical
                return True
            if r <= rows - k and c <= cols - k and all(board[r+i][c+i] == player for i in range(k)):  # Diagonal
                return True
            if r >= k - 1 and c <= cols - k and all(board[r-i][c+i] == player for i in range(k)):  # Anti-diagonal
                return True
    return False

# Heuristic function to evaluate the board state
def evaluate_board(board, player, k):
    # Check if the player has already won
    if check_win(board, player, k):
        return 1000  # Win
    elif check_win(board, 3-player, k):  # Opponent
        return -1000  # Loss
    return 0  # No winner yet

# Minimax algorithm with Alpha-Beta Pruning
def minimax(board, depth, alpha, beta, is_maximizing_player, player, k):
    # Evaluate the current board state
    score = evaluate_board(board, player, k)
    if score == 1000 or score == -1000:
        return score
    if depth == 0:
        return 0  # Depth limit reached

    if is_maximizing_player:
        best = -float('inf')
        for r in range(len(board)):
            for c in range(len(board[0])):
                if board[r][c] == 0:  # Empty cell
                    board[r][c] = player  # Make move
                    best = max(best, minimax(board, depth-1, alpha, beta, False, player, k))
                    board[r][c] = 0  # Undo move
                    alpha = max(alpha, best)
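                    if beta <= alpha:
                        break  # Prune the remaining cells in this row
            if beta <= alpha:
                break  # Also leave the outer loop: a single break only exits the inner one
        return best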
    else:
        best = float('inf')
        for r in range(len(board)):
            for c in range(len(board[0])):
                if board[r][c] == 0:  # Empty cell
                    board[r][c] = 3 - player  # Opponent's move
                    best = min(best, minimax(board, depth-1, alpha, beta, True, player, k))
                    board[r][c] = 0  # Undo move
                    beta = min(beta, best)
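                    if beta <= alpha:
                        break  # Prune the remaining cells in this row
            if beta <= alpha:
                break  # Also leave the outer loop: a single break only exits the inner one
        return best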

# Function to find the best move for the AI
def best_move(board, player, k):
    best_val = -float('inf')
    best_move = (-1, -1)
    for r in range(len(board)):
        for c in range(len(board[0])):
            if board[r][c] == 0:  # Empty cell
                board[r][c] = player  # Make move
                move_val = minimax(board, 4, -float('inf'), float('inf'), False, player, k)  # Depth 4 for searching
                board[r][c] = 0  # Undo move
                if move_val > best_val:
                    best_move = (r, c)
                    best_val = move_val
    return best_move

# Function to play the game
def play_game(board_size=(3, 3), k=3):
    board = [[0 for _ in range(board_size[1])] for _ in range(board_size[0])]
    player = 1  # 1 - X, 2 - O

    while True:
        # Display the board
        for row in board:
            print(' '.join(str(x) if x != 0 else '.' for x in row))
        print()

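        # Check if the player who just moved has won
        # (player was already switched, so the last mover is 3 - player)
        last_player = 3 - player
        if check_win(board, last_player, k):
            print(f"Player {last_player} wins!")
            break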

        # AI's move
        if player == 1:
            row, col = best_move(board, player, k)
            print(f"AI (X) move: {row}, {col}")
        else:
            # Simple random move for player O
            row, col = random.choice([(r, c) for r in range(board_size[0]) for c in range(board_size[1]) if board[r][c] == 0])
            print(f"Player O move: {row}, {col}")

        board[row][col] = player
        player = 3 - player  # Switch player

        # Check for a draw
        if all(board[r][c] != 0 for r in range(board_size[0]) for c in range(board_size[1])):
            print("It's a draw!")
            break

# Speed test function
def test_speed(board_size=(3, 3), k=3):
    start_time = time.time()
    play_game(board_size, k)
    print(f"Game duration: {time.time() - start_time} seconds")

# Example: play on a 8x8 board, with 4 in a row to win
test_speed((8,8), 4)

I tried limiting the depth of the Minimax search to reduce the number of possibilities being evaluated. For example, I reduced the search depth to 2 for larger boards:

move_val = minimax(board, depth=2, alpha=-float('inf'), beta=float('inf'), is_maximizing_player=False, player=player, k=k)

However, this only partially mitigates the issue, and the AI's decision-making still seems too slow.

I also explored parallelizing the move evaluation using ThreadPoolExecutor from Python's concurrent.futures module:

from concurrent.futures import ThreadPoolExecutor
def parallel_minimax(board, depth, alpha, beta, is_maximizing_player, player, k):
    with ThreadPoolExecutor() as executor:
        futures = []
        for r in range(len(board)):
            for c in range(len(board[0])):
                if board[r][c] == 0:
                    board[r][c] = player
                    futures.append(executor.submit(minimax, board, depth-1, alpha, beta, False, player, k))
                    board[r][c] = 0  # Undo move
        for future in futures:
            result = future.result()
            # [process the result]

This approach improved parallelism, but it still does not solve the underlying issue of time complexity for large boards.

Please help me make the algorithm as fast as possible.


Solution

  • Minimax, even with alpha-beta pruning, will have to look at an exponentially growing number of states. For larger board sizes this will mean you can only perform shallow searches.
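
    To get a feel for that growth, here is a rough back-of-the-envelope count. It is a sketch, not a measurement: it assumes no pruning and no early wins, and the helper name unpruned_leaf_count is made up for this illustration. The best_move function in the question plays one move and then calls minimax with depth 4, so it looks 5 plies ahead:

```python
from math import perm

def unpruned_leaf_count(empty_cells: int, plies: int) -> int:
    # Number of distinct move sequences of the given length when every
    # empty cell is a legal move: n * (n-1) * ... * (n-plies+1)
    return perm(empty_cells, min(plies, empty_cells))

# Board sizes mentioned in the question, all searched 5 plies deep
for rows, cols in ((3, 3), (6, 6), (8, 8), (8, 12)):
    print(f"{rows}x{cols}: {unpruned_leaf_count(rows * cols, 5):,} sequences")
```

    Alpha-beta pruning reduces these numbers, but the growth stays exponential in the search depth.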

    I would suggest switching to Monte Carlo tree search. It uses random sampling to decide whether to explore new branches in the search tree or to deepen existing ones. You can check out the Wikipedia page on Monte Carlo tree search for more information.
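
    The choice between exploring and deepening is commonly made with the UCB1 formula. The following is a minimal sketch of that selection rule in its textbook form; the function names and the exploration constant c = sqrt(2) are assumptions for illustration and not necessarily the scaling any particular implementation uses:

```python
from math import log, sqrt

def ucb1(total_score: float, visits: int, parent_visits: int, c: float = sqrt(2)) -> float:
    # Average result so far (exploitation) plus a bonus that shrinks as
    # the child gets visited more often (exploration)
    return total_score / visits + c * sqrt(log(parent_visits) / visits)

def select_child(children: list[tuple[float, int]]) -> int:
    # children holds (total_score, visits) pairs; every child is assumed
    # to have been visited at least once
    parent_visits = sum(visits for _, visits in children)
    scores = [ucb1(score, visits, parent_visits) for score, visits in children]
    return scores.index(max(scores))

# A well-explored, strong child versus a barely explored one:
# the exploration bonus makes the second child the next one to try
print(select_child([(60.0, 100), (1.0, 3)]))
```

    When two children have the same number of visits, the bonus is equal and the one with the better average is selected.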

    You mention 4-in-a-row, but realise that on large boards (like 8x12), that is an easy win for the first player.

    Below is an implementation I made to answer your question. It defines a TicTacToe class with methods like best_move, play_row_col, ... It inherits from a more generic class, which provides the core Monte Carlo search functionality through its mc_search method. That superclass has no knowledge of the game logic; it only assumes a turn-based game between two players with a finite number of moves available at each turn. It refers to these moves only by a sequential number and depends on the subclass methods to perform the corresponding moves. The superclass is responsible for driving the search directions.

    The TicTacToe class has no search logic; it depends on the superclass for that. Although it is not strictly necessary, I decided to help the search find winning lines more effectively by adding some logic to TicTacToe that narrows the list of moves the algorithm should consider, in two ways:

    1. Only moves that are neighboring (possibly diagonally) occupied cells are considered. For the very first move only the center move is considered. This is a limitation that you might consider too strong, as in the beginning of the game on a large board, strong players may prefer to place their pieces further away from the other pieces. Still, I found that this restriction on the search algorithm allows for reasonably good play (it's all relative to what strength you expect).

    2. If there is a winning move for the player whose turn it is, this is detected, and only that move will be in the move list. If the last played piece created a "threat" to win on their next move, this also leads to a move list with just one move for the opponent, since they will want to defend.

    These restrictions are only applicable for the search algorithm, not for the human player (of course).
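
    The first restriction can be sketched on the 2D board representation from the question (0 means empty). This is an illustration only; the function name candidate_moves is hypothetical, and the code further down maintains its move list incrementally on a flat board instead of rescanning:

```python
def candidate_moves(board: list[list[int]]) -> list[tuple[int, int]]:
    rows, cols = len(board), len(board[0])
    moves = []
    for r in range(rows):
        for c in range(cols):
            if board[r][c] != 0:
                continue  # occupied cell, not a move
            # Keep the empty cell if any of its up to 8 neighbors is occupied
            if any(
                board[nr][nc] != 0
                for nr in range(max(r - 1, 0), min(r + 2, rows))
                for nc in range(max(c - 1, 0), min(c + 2, cols))
            ):
                moves.append((r, c))
    if moves:
        return moves
    # No candidates: the board is either empty (play the center) or full
    center = (rows // 2, cols // 2)
    return [center] if board[center[0]][center[1]] == 0 else []

board = [[0] * 12 for _ in range(8)]
board[4][6] = 1
print(len(candidate_moves(board)))  # 8 candidates instead of 95 empty cells
```

    The smaller move list lowers the branching factor, at the cost of never considering moves far from the existing pieces.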

    As on larger boards the game may need a lot of moves to come to an end, I also added a parameter to limit the search depth in the Monte Carlo rollout phase, and consider the outcome a draw when there is no win within that number of moves counting from the start of the rollout. It could be set to 20 or 30 for example.
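
    As a rough illustration of such a depth-limited rollout, here is a generic sketch. The rollout function, its callback parameters and the toy counting game are all made up for this example and are not part of the program below:

```python
import random
from typing import Callable, Optional

def rollout(state, legal_moves: Callable, play: Callable, winner: Callable,
            max_depth: int) -> Optional[int]:
    # Play random moves until someone wins, no moves remain, or the
    # depth limit is hit. Returns the winner, or None for a draw.
    for _ in range(max_depth):
        won_by = winner(state)
        if won_by is not None:
            return won_by
        moves = legal_moves(state)
        if not moves:
            return None  # no moves left: a real draw
        state = play(state, random.choice(moves))
    return winner(state)  # None here means "counted as a draw"

# Toy game: players alternately add 1 or 2; whoever reaches 10 wins.
# A state is (total, player_to_move).
def toy_moves(state):
    return [1, 2]

def toy_play(state, move):
    total, turn = state
    return (total + move, 1 - turn)

def toy_winner(state):
    total, turn = state
    return 1 - turn if total >= 10 else None  # the previous mover won

print(rollout((0, 0), toy_moves, toy_play, toy_winner, max_depth=3))   # None: limit reached
print(rollout((0, 0), toy_moves, toy_play, toy_winner, max_depth=20))  # 0 or 1
```

    With a limit of 3 the total can be at most 6, so the rollout is always scored as a draw; with a limit of 20 every rollout reaches a winner.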

    Here is the code:

    from __future__ import annotations
    from math import log
    from time import perf_counter_ns
    from random import shuffle, randrange
    from abc import abstractmethod
    from typing import Callable
    from enum import StrEnum
    from copy import deepcopy
    
    # Helper function to get best list entry based on callback function
    def best_index(lst: [any], evaluate: Callable[[any], int]) -> int:
        results = list(map(evaluate, lst))
        return results.index(max(results))
    
    # Rather generic, abstract class for 2-player game
    class MonteCarloGame:
        class States(StrEnum):
            FIRST_PLAYERS_TURN = "X"
            SECOND_PLAYERS_TURN = "O"
            FIRST_PLAYER_WON = "Player X won"
            SECOND_PLAYER_WON = "Player O won"
            DRAW = "It's a draw"
    
        TerminalStates = (States.FIRST_PLAYER_WON, States.DRAW, States.SECOND_PLAYER_WON)
        TurnStates = (States.FIRST_PLAYERS_TURN, States.SECOND_PLAYERS_TURN)
    
        # Node in the Monte Carlo search tree has the sum of scores over a number of games:
        class Node:
            def __init__(self, num_children: int, last_player: MonteCarloGame.States):
                self.last_player = last_player
                self.score = 0
                self.max_score = 0
                self.children: [MonteCarloGame.Node] = [None] * num_children
                self.unvisited = list(range(num_children))
                shuffle(self.unvisited)
    
            # To get a random child (i.e. move to next state)
            def pick_child(self) -> int:
                if self.unvisited:  # If there are unvisited child nodes, choose one of those first
                    return self.unvisited.pop()  # = a randomly chosen Node to expand
                if not self.children:
                    raise ValueError("Cannot call pick_child on a terminal node")
                # Use UCB1 formula to choose which node to explore further
                log_n = log(self.max_score)
                return best_index(self.children, lambda child: child.score / child.max_score + (log_n / child.max_score)**0.5)
    
            def update_score(self, score: int):
                self.max_score += 2  # Every simulated game counts for 2 points (i.e. the maximum score)
                # Make the score relative to the player who made the last move
                self.score += 2 - score if self.last_player == MonteCarloGame.States.FIRST_PLAYERS_TURN else score  # A score can be 0, 1 or 2 (loss, draw, win).
    
            def select_or_expand(self, index: int, num_children: int, last_player: MonteCarloGame.States) -> MonteCarloGame.Node:
                child = self.children[index] or MonteCarloGame.Node(num_children, last_player)
                self.children[index] = child
                return child
    
        def __init__(self):
            self.mc_root = None
    
        @abstractmethod
        def copy(self) -> MonteCarloGame:
            return self # must implement
    
        @abstractmethod
        def size_of_move_list(self) -> int:
            return 0 # must implement
    
        @abstractmethod
        def state(self) -> States:
            return self.States.FIRST_PLAYERS_TURN # must implement
    
        @abstractmethod
        def play_from_move_list(self, move_index: int) -> States:
            return self.States.FIRST_PLAYERS_TURN # must implement
    
        # To be called when a move is played from the subclass, so
        #   that the root can "follow" to the corresponding child node
        def mc_update(self, move_index: int):
            if not self.mc_root or move_index == -1:
                self.mc_root = MonteCarloGame.Node(self.size_of_move_list(), self.state())
            else:
                self.mc_root = self.mc_root.select_or_expand(move_index, self.size_of_move_list(), self.state())
    
        # Main Monte Carlo Search algorithm:
        def mc_search(self, timeout_ms: int, rollout_depth: int) -> int:
            if self.state() in self.TerminalStates:
                return -1 # Nothing to do
            if not self.mc_root:
                self.mc_root = MonteCarloGame.Node(self.size_of_move_list(), self.state())
            expiry = perf_counter_ns() + timeout_ms * 1_000_000
            while perf_counter_ns() < expiry:
                # Start at root (current game state)
                game = self.copy()
                node = self.mc_root
                state = game.state()
                path = [node]
                # Traverse down until we have expanded a node, or reached a terminal state
                while state in self.TurnStates and node.max_score:
                    move_index = node.pick_child()
                    next_state = game.play_from_move_list(move_index)
                    node = node.select_or_expand(move_index, game.size_of_move_list(), state)
                    path.append(node)
                    state = next_state
                # Rollout
                score = 1  # Default outcome is a draw (when depth limit is reached)
                for depth in range(rollout_depth):  # Limit the depth
                    if state not in self.TurnStates:
                        score = self.TerminalStates.index(state)  # Absolute score (good for player #2)
                        break
                    state = game.play_from_move_list(randrange(game.size_of_move_list()))
                # Back propagate the score up the search tree
                for node in path:
                    node.update_score(score)
    
            # Choose the most visited move
            return best_index(self.mc_root.children, lambda n: n.max_score if n else 0)
    
    
    class TicTacToe(MonteCarloGame):
        class Contents(StrEnum):
            INVALID = "#"
            EMPTY = "."
            FIRST_PLAYER = MonteCarloGame.States.FIRST_PLAYERS_TURN
            SECOND_PLAYER = MonteCarloGame.States.SECOND_PLAYERS_TURN
    
        def __init__(self, num_rows: int=3, num_cols: int=3, win_length: int=3, *, source: TicTacToe=None):
            super().__init__()
            if source:  # copy
                self.__dict__ = deepcopy(source.__dict__)
                return
            self.num_rows = num_rows
            self.num_cols = num_cols
            self.win_length = win_length
            width = num_cols + 1 # include dummy column
            bottom = width * (num_rows + 1)
            # Surround board with dummy row and column, and create a flat representation:
            self.board = [
                (self.Contents.INVALID if i < width or i >= bottom or i % width == 0 else self.Contents.EMPTY)
                for i in range(width * (num_rows + 2) + 1) # flat structure
            ]
            self.free_cells_count = num_cols * num_rows
            # Maintain a reasonable move list (not ALL free cells) to keep search efficient
            # At the start we only consider one move (can adapt to allow some more...)
            # This list is not relevant when a "human" player plays a move (any free cell is OK)
            self.move_list = [self.row_col_to_cell(num_rows//2, num_cols//2)]
            self._state = self.States.FIRST_PLAYERS_TURN
            # Add some information about current threats on the board (to narrow rollout paths)
            self.open_threats: [int] = []  # board cell indices where the last player has opportunity to win
            self.immediate_win = 0  # board cell index where the next move wins the game
            self.forced_cell = 0    # derived from previous two attributes
    
        def copy(self) -> TicTacToe:
            return TicTacToe(source=self)
    
        def size_of_move_list(self) -> int:
            if self._state in self.TerminalStates:
                return 0
            # If we are on a forced line to a win or to avoid a loss, consider one move only
            if self.forced_cell:
                return 1
            return len(self.move_list)
    
        def state(self) -> MonteCarloGame.States:
            return self._state
    
        def play_from_move_list(self, move_index: int) -> MonteCarloGame.States:
            return self.play_at_cell(self.move_index_to_cell(move_index))
    
        def play_at_cell(self, cell: int) -> MonteCarloGame.States:
            n = len(self.board)
            turn = self._state
            self._state = self.TurnStates[1-self.TurnStates.index(turn)] # toggle
            self.board[cell] = self.Contents(turn)  # Type cast to silence warning
            self.free_cells_count -= 1
            if cell in self.move_list:
                self.move_list.remove(cell)
            # Add surrounding cells to move_list for next turn
            w = self.num_cols
            for neighbor in (cell-1,cell+1,cell-w-2,cell-w-1,cell-w,cell+w,cell+w+1,cell+w+2):
                if self.board[neighbor] == self.Contents.EMPTY and neighbor not in self.move_list:
                    self.move_list.append(neighbor)
            threats = set()  # Collect cells where the move creates a threat (where opponent must play at)
            # check if this move makes a direct win or a threat
            for step in (1, self.num_cols, self.num_cols + 1, self.num_cols + 2):
                span = self.win_length * step
                i = next(k for k in range(cell - step, -1, -step) if self.board[k] != turn)
                j = next(k for k in range(cell + step, n, step) if self.board[k] != turn)
                if j - i > span:
                    # it's an immediately winning move
                    self._state = (self.States.FIRST_PLAYER_WON if turn == self.States.FIRST_PLAYERS_TURN
                                   else self.States.SECOND_PLAYER_WON)
                    return self._state
                # look for a single gap in a line, such that this gap is a threat
                if self.board[i] == self.Contents.EMPTY:
                    i2 = next(k for k in range(i - step, -1, -step) if self.board[k] != turn)
                    if j - i2 > span:
                        threats.add(i)
                if self.board[j] == self.Contents.EMPTY:
                    j2 = next(k for k in range(j + step, n, step) if self.board[k] != turn)
                    if j2 - i > span:
                        threats.add(j)
    
            if not self.free_cells_count:
                self._state = self.States.DRAW
    
            self.immediate_win = next((c for c in self.open_threats if c != cell), 0)
            self.open_threats = list(threats)
            self.forced_cell = self.immediate_win or next(iter(self.open_threats), 0)
            return self._state
    
        def play_row_col(self, row: int, col: int) -> MonteCarloGame.States:
            if not 0 <= row < self.num_rows or not 0 <= col < self.num_cols:
                raise ValueError("row/col out of range")
            cell = self.row_col_to_cell(row, col)
            if self.board[cell] != TicTacToe.Contents.EMPTY:
                raise ValueError("cell is already occupied")
            move_index = self.cell_to_move_index(cell)
            state = self.play_at_cell(cell)
            self.mc_update(move_index)  # must be called after the move is played
            return state
    
        # Some methods that convert between different ways to identify a move
        def move_index_to_cell(self, move_index: int) -> int:
            # Convert index in the move list to cell in the flattened board
            # If we are on a forced line to a win or to avoid a loss, consider that move only
            return self.forced_cell or self.move_list[move_index]
    
        def cell_to_move_index(self, cell: int) -> int:
            # Convert cell in the flattened board to index in the move list
            cells = [self.forced_cell] if self.forced_cell else self.move_list
            try:
                return cells.index(cell)
            except ValueError:
                return -1
    
        def row_col_to_cell(self, row: int, col: int) -> int:
            # Convert row / col to the cell in the flattened board
            return (row + 1) * (self.num_cols + 1) + col + 1
    
        def cell_to_row_col(self, cell: int) -> tuple[int, int]:
            # Convert cell in flat board structure back to row/col:
            return cell // (self.num_cols + 1) - 1, cell % (self.num_cols + 1) - 1
    
        # The method that starts the search for the best move, and returns it
        def best_move(self, timeout_ms: int, rollout_depth: int=15) -> tuple[int, int]:
            return self.cell_to_row_col(self.move_index_to_cell(
                self.mc_search(timeout_ms, rollout_depth)
            ))
    
        def __repr__(self) -> str:
            return "#" * (self.num_cols * 2 + 2) + " ".join(
                (content, "#\n#")[content == self.Contents.INVALID]
                for content in self.board[self.num_cols+1:-self.num_cols-2]
            ) + " #\n" + "#" * (self.num_cols * 2 + 3)
    
    
    def main():
        # Create the game instance and set parameters
        game = TicTacToe(num_rows=8, num_cols=12, win_length=5)
        timeout_ms = 2000  # milliseconds that tree search may spend for determining a good move
        rollout_depth = 20  # max number of moves in a rollout game: if reached it counts as a draw
        human = (MonteCarloGame.States.SECOND_PLAYERS_TURN, )
    
        # Game loop
        state = game.state()
        while state in MonteCarloGame.TurnStates:
            print(game)
            print(f"Player {state} to move...")
            if state in human:
                try:
                    state = game.play_row_col(int(input(f"Enter row (0-{game.num_rows-1}): ")),
                                              int(input(f"Enter column (0-{game.num_cols-1}): ")))
                except ValueError:
                    print("Invalid move. Try again:")
            else:
                state = game.play_row_col(*game.best_move(timeout_ms, rollout_depth))
                print("Computer has played:")
    
        print(game)
        print(f"Game over: {state}")
    
    main()
    

    The above can be run as-is (it needs Python 3.11 or later because of enum.StrEnum) and will play a computer-vs-human game on an 8x12 board with a 5-in-a-row target, with the computer getting 2 seconds per move and a rollout depth of 20. Have a go at it, and try to beat it. It is possible.

    This is not an end-product, but I think this program plays reasonable games.

    For smaller boards, like 3x3, you don't need 2 seconds, and could set it to just 100 milliseconds.

    Several things could be improved, like:

    • The search could keep going during the time that the human player is to enter their move.

    • The search time could be more dynamic so that it would return a best move faster when it becomes clear that one move stands out among the rest.

    • More logic could be added to the TicTacToe class so that it would detect forced lines at an earlier stage and limit the move list (for the Monte Carlo search) to the relevant moves to follow those lines. Think of the creation of double threats, ...etc.

    • The depth limit for the rollout phase could be made more dynamic, depending on the size of the board, the time still available,...
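
    As an example of the second idea, the search returns the most visited move, so it could stop as soon as the runner-up can no longer catch up. The sketch below assumes the remaining budget is expressed as a number of simulations (it would have to be estimated from the time left and the simulation rate); the function name can_stop_early is hypothetical:

```python
def can_stop_early(visit_counts: list[int], remaining_simulations: int) -> bool:
    # The most visited move is returned at the end, so the search can stop
    # once the runner-up cannot catch up even if it received every
    # remaining simulation
    if len(visit_counts) < 2:
        return True  # nothing to choose between
    best, second = sorted(visit_counts, reverse=True)[:2]
    return best - second > remaining_simulations

print(can_stop_early([500, 100, 50], remaining_simulations=300))  # True
print(can_stop_early([500, 450], remaining_simulations=300))      # False
```

    Such a check would run every so many iterations inside the search loop rather than after each simulation, to keep its overhead small.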

    I hope this meets some of your requirements and gives some leads on how to improve it further. At least I was happy with the result that fewer than 300 lines of code could give.