
Torch - Lua / Get max index from matrix


I am trying to code a neural net for the game Domineering.
The input is an 8 x 8 x 3 matrix, organised as follows:
the first plane is the state of the game, the second plane is the flipped board, and the last plane is the player plane.
The target output is an 8 x 8 matrix marking the best move to play, i.e. the move to learn (generated by Monte Carlo Tree Search).

The network's output is an 8 x 8 tensor holding, for each cell, the probability of that move being the best one to play. I need to get the index (x, y) of the maximum probability in this tensor so that I can compare it with the target.

I tried torch.max(tensor, 2) and torch.max(tensor, 1), but neither gave me what I need.
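For intuition on why those calls fall short: torch.max(t, 2) reduces a 2-D tensor along its second dimension, returning for each row the largest value and the column where it occurs, so it yields one result per row rather than a single (x, y). The plain-Lua sketch below mimics that reduction on a nested table; `maxAlongRows` is a hypothetical helper (not part of Torch), and a 3 x 3 toy board stands in for the 8 x 8 one.

```lua
-- Plain-Lua sketch of what torch.max(t, 2) computes for a 2-D tensor:
-- for every row, the largest value and the column where it occurs.
-- maxAlongRows is a hypothetical helper, not part of Torch.
local function maxAlongRows(m)
  local values, indices = {}, {}
  for r = 1, #m do
    local bestV, bestC = m[r][1], 1
    for c = 2, #m[r] do
      if m[r][c] > bestV then
        bestV, bestC = m[r][c], c
      end
    end
    values[r], indices[r] = bestV, bestC
  end
  return values, indices
end

-- A small 3x3 board of scores instead of the 8x8 network output
local scores = {
  {0.10, 0.05, 0.20},
  {0.02, 0.30, 0.01},
  {0.15, 0.07, 0.10},
}
local rowMax, rowArgmax = maxAlongRows(scores)
-- One result per row, not a single (x, y)
print(table.concat(rowArgmax, " "))  -- 3 2 1
```

A second reduction over `rowMax` picks the best row, but the column still has to be looked up in `rowArgmax` for that row, which is the step missing from the code below.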

Does anyone have a clue how to do this?

Thanks a lot!

-- out is the output of the neural net; the target is output[index][1]
-- need to check whether the target is the same as the prediction
max, bestTarget = torch.max(output[index][1], 2)
maxP, bestPrediction = torch.max(out, 2)
max, indT = torch.max(max, 1)
maxP, indP = torch.max(maxP, 1)

Solution

  • To get the position (best_row, best_col) of the maximum element of out:

    -- First get the maximum element and its column for each row (both 8x1)
    maxP_per_row, bestColumn_per_row = torch.max(out, 2)
    -- then get the best element and the best row (both 1x1 tensors)
    best_p, best_row = torch.max(maxP_per_row, 1)
    -- unwrap the 1x1 tensors into plain Lua numbers
    best_p, best_row = best_p[1][1], best_row[1][1]
    -- then look up the best column for the best row
    best_col = bestColumn_per_row[best_row][1]
    

    You can do the same for the target and compare the two positions. Hope this helps.
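An alternative, not from the original answer and offered only as an assumption-based sketch: in Torch one could flatten the board with `out:view(out:nElement())` and call `torch.max` on dimension 1 of the result, which gives a single 1-based linear index that must then be mapped back to (row, col). The plain-Lua code below (no Torch required) shows that mapping and a prediction-versus-target comparison; `linearToRowCol` and `argmax2d` are hypothetical helpers, and the 3 x 3 boards are toy data. It uses `math.floor` rather than `//` so it also runs under LuaJIT, which Torch is built on.

```lua
-- Convert a 1-based, row-major linear index into (row, col).
-- This is the step needed after taking the max of a flattened board.
local function linearToRowCol(idx, numCols)
  local row = math.floor((idx - 1) / numCols) + 1
  local col = (idx - 1) % numCols + 1
  return row, col
end

-- Plain-Lua argmax over a nested table: scan cells in row-major order,
-- remember the linear index of the first maximum, then convert it back.
-- argmax2d is a hypothetical helper.
local function argmax2d(m)
  local numCols = #m[1]
  local bestV, bestIdx = m[1][1], 1
  for r = 1, #m do
    for c = 1, numCols do
      if m[r][c] > bestV then
        bestV, bestIdx = m[r][c], (r - 1) * numCols + c
      end
    end
  end
  local row, col = linearToRowCol(bestIdx, numCols)
  return row, col, bestV
end

-- Compare the predicted move with the MCTS target move (toy 3x3 data)
local prediction = {
  {0.10, 0.05, 0.20},
  {0.02, 0.30, 0.01},
  {0.15, 0.07, 0.10},
}
local target = {
  {0, 0, 0},
  {0, 1, 0},
  {0, 0, 0},
}
local pr, pc = argmax2d(prediction)
local tr, tc = argmax2d(target)
print(pr == tr and pc == tc)  -- true
```

Because the board is stored row by row, a linear index i on an 8 x 8 board maps to row floor((i - 1) / 8) + 1 and column (i - 1) % 8 + 1.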