Search code examples
python algorithm artificial-intelligence pruning

Error in alpha-beta pruning algorithm in Python


In the following pruning code, the alpha returned is correct while the beta remains the same. What am I doing wrong? The tree has the following values at its bottom nodes:

tree = [[[5, 1, 2], [8, -8, -9]], [[9, 4, 5], [-3, 4, 3]]]
root = 0
pruned = 0


def children(branch, depth, alpha, beta):
    """Run alpha-beta over ``branch`` and return the resulting ``(alpha, beta)``.

    Even depths are maximising levels (they can only raise ``alpha``); odd
    depths are minimising levels (they can only lower ``beta``).  A nested
    list is a subtree; anything else is treated as a leaf score.

    Side effects:
      * every visited subtree inside ``branch`` is overwritten in place with
        the running bound of this level (``alpha`` on even depths, ``beta``
        on odd depths);
      * the global ``pruned`` counter is incremented on every cutoff;
      * when ``depth == root`` the global ``tree`` is replaced by the final
        value (``alpha`` if ``root == 0`` else ``beta``).
    """
    global tree
    global pruned
    maximizing = depth % 2 == 0
    # enumerate() keeps the slot index in step with *every* child, so a
    # branch that mixes leaves and subtrees is collapsed in the right slot.
    for i, child in enumerate(branch):
        if isinstance(child, list):
            nalpha, nbeta = children(child, depth + 1, alpha, beta)
            if maximizing:
                alpha = max(alpha, nbeta)
            else:
                beta = min(beta, nalpha)
            branch[i] = alpha if maximizing else beta
        elif maximizing:
            alpha = max(alpha, child)
        else:
            beta = min(beta, child)
        # The cutoff test must follow subtrees as well as leaves: once the
        # window is empty, the remaining siblings cannot change the result.
        if alpha >= beta:
            pruned += 1
            break
    if depth == root:
        tree = alpha if root == 0 else beta
    return (alpha, beta)

def alphabeta(in_tree=tree, start=root, lower=-15, upper=15):
    """Search ``in_tree`` with alpha-beta pruning and report the outcome.

    Args:
        in_tree: nested-list game tree to search; it is handed to
            ``children``, which collapses it in place.
        start: depth assigned to the top node (its parity decides whether
            that level maximises or minimises).
        lower: initial alpha bound.
        upper: initial beta bound.

    Returns:
        ``(alpha, beta, tree, pruned)`` where ``tree`` and ``pruned`` are the
        module-level values as left behind by ``children``.
    """
    # Search the tree the caller actually passed.  Using the global ``tree``
    # here would ignore ``in_tree`` and would fail on a second call, because
    # ``children`` replaces the global with a plain number at the root.
    (alpha, beta) = children(in_tree, start, lower, upper)

    if __name__ == "__main__":
        print("(alpha, beta): ", alpha, beta)
        print("Result: ", tree)
        print("Times pruned: ", pruned)

    return (alpha, beta, tree, pruned)


if __name__ == "__main__":
    alphabeta()

Is the code even right, or should I approach it differently? EDIT: The problem most likely stems from the modulo (%) in the beta section.

EDIT2 UPDATED CODE

tree = [[[1, 8], [5], [6, 4, 7], [9], [3, 2], [6, 10, 2]]]
side = 1
alpha = -1000
beta = 1000
depth = 3
p = []
betacuts = []
alphacuts = []
counta = -1
countb = -1


def getLengthLoL(position):
    """Look up the node of the global ``tree`` addressed by ``position``.

    ``position`` is a sequence of child indices, followed one level at a
    time starting from the top of ``tree``.

    Returns:
        The node itself when it is an ``int`` (a leaf score), otherwise the
        number of children it has.
    """
    # Walk the path that was passed in, to any depth, instead of reading the
    # global ``p`` through a fixed set of hand-written index expressions.
    node = tree
    for index in position:
        node = node[index]
    return node if isinstance(node, int) else len(node)
def makeMove(move):
    """Hand the turn to the other side, then push ``move`` onto the path ``p``."""
    global side
    # Any truthy ``side`` becomes 0; a falsy one becomes 1.
    side = 0 if side else 1
    p.append(move)

def backMove(move):
    """Undo the latest move: flip ``side`` back and drop the last entry of ``p``.

    The ``move`` argument is accepted but not used.
    """
    global side
    side = 0 if side else 1
    p.pop()

def evaluation(score):
    """Return ``score`` negated when the global ``side`` is 0, else unchanged."""
    return -1 * score if side == 0 else score

def minmax(alpha, beta, depth):
    """Negamax search with an alpha-beta window over the current position.

    The position is the global path ``p``.  Each child is scored by the
    negated result of searching it with the negated, swapped window.

    Returns the child score that met or exceeded ``beta`` (cutoff), or else
    the best ``alpha`` found.  Every cutoff score is appended to ``betacuts``
    (and ``countb`` is incremented); every score that raises ``alpha`` is
    appended to ``alphacuts`` (and ``counta`` is incremented).
    """
    global counta, countb
    # Leaf score at depth 0, number of available moves otherwise.
    node = getLengthLoL(p)
    if depth == 0:
        return evaluation(node)
    for move in range(int(node)):
        makeMove(move)
        score = -1 * minmax(-beta, -alpha, depth - 1)
        backMove(move)
        if score >= beta:
            # Cutoff: record it and stop looking at the remaining moves.
            betacuts.append(score)
            countb += 1
            return score
        if score > alpha:
            alphacuts.append(score)
            counta += 1
            alpha = score
    return alpha


# Search from the top of the tree with the full window, then show the
# recorded cut values followed by the final score.
myScore = minmax(alpha, beta, depth)
print(betacuts, alphacuts)
print(myScore)

This code prints the wrong alphas and betas from the start.


Solution

  • This is a more traditional approach. I have not double-checked it, but I know this is the correct approach. The variable p encodes the "position" (the list of child indices leading from the root to the current node). The code will only be accurate if all of the tree's branches have the same depth; that is why the depth variable is set to 3 here. A little more work is needed to make it run on any tree.

    # Game tree: three levels of branching with integer scores at the leaves.
    tree = [[[0,1,2],[-1,2,5],[-2,2,0]],[[-2,-1,-3],[-4,-3,-1],[1,2,8]],[[4,6,1],[1,7,-1],[-2,-4,1]]]

    # Side to move, initial search window and search depth.
    side = 1
    alpha = -1000
    beta = 1000
    depth = 3

    # Path of child indices from the top of the tree to the current node.
    p = []

    def getLengthLoL(l, address):
        """Follow ``address`` (a sequence of indices) down the nested list ``l``.

        Returns the number of children when the node reached is a list, and
        the node itself otherwise.
        """
        node = l
        for step in address:
            node = node[step]
        if isinstance(node, list):
            return len(node)
        return node
    
    def makeMove(move):
        """Switch the side to move and record ``move`` at the end of ``p``."""
        global side
        side = 0 if side else 1
        p.append(move)
    
    def backMove(move):
        """Switch the side to move back and remove the last entry of ``p``.

        ``move`` itself is not used.
        """
        global side
        side = 0 if side else 1
        p.pop()
    
    def evaluation(score):
        """Return ``score`` as is, or negated when the global ``side`` is 0."""
        if side != 0:
            return score
        return -1 * score
    
    def minmax(alpha, beta, depth):
        """Negamax search with an alpha-beta window from the position in ``p``.

        Each child is scored by the negated result of searching it with the
        negated, swapped window.  Returns ``beta`` as soon as a child scores
        at least ``beta``; otherwise returns the highest ``alpha`` reached.
        """
        # Leaf score at depth 0, number of available moves otherwise.
        here = getLengthLoL(tree, p)
        if depth == 0:
            return evaluation(here)
        for move in range(int(here)):
            makeMove(move)
            score = -1 * minmax(-beta, -alpha, depth - 1)
            backMove(move)
            if score >= beta:
                return beta
            if alpha < score:
                alpha = score
        return alpha
    
    # Search from the top of the tree with the full window and show the score.
    myScore = minmax(alpha,beta,depth)
    # Call form of print: valid on Python 3 and, with a single argument,
    # prints the same text on Python 2.
    print(myScore)