Tags: python, keras, reinforcement-learning, reward

How do I train a classifying neural net with a bad (negative) reward?


I am trying to train a neural net to play Tic Tac Toe via reinforcement learning, using Keras in Python. Currently the net receives the current board as input:

    array([0,1,0,-1,0,1,0,0,0])
where 1 = X, -1 = O and 0 = an empty field.
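A minimal sketch of the board encoding described above, assuming the 3x3 board is flattened row by row; the helper `encode_board` and the `"X"`/`"O"`/`" "` cell symbols are hypothetical, not part of the original code:

```python
# Hypothetical mapping from cell symbols to the input values used above.
SYMBOL_TO_VALUE = {"X": 1, "O": -1, " ": 0}

def encode_board(rows):
    """Flatten a 3x3 board (row-major, assumed) into a list of 9 ints."""
    return [SYMBOL_TO_VALUE[cell] for row in rows for cell in row]

board = [
    [" ", "X", " "],
    ["O", " ", "X"],
    [" ", " ", " "],
]
print(encode_board(board))  # [0, 1, 0, -1, 0, 1, 0, 0, 0]
```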

If the net wins a game, it gets a reward for every action (output) it took, e.g. the target [0,0,0,0,1,0,0,0,0]. If the net loses, I want to train it with a bad reward instead, e.g. [0,0,0,0,-1,0,0,0,0].
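The targets described above can be sketched as follows; `make_target` and `argmax` are hypothetical helpers, and index 4 (the centre cell) is just the example move from the question:

```python
def make_target(action, reward, n_cells=9):
    """Hypothetical helper: `reward` at the chosen cell, 0 everywhere else."""
    target = [0] * n_cells
    target[action] = reward
    return target

def argmax(values):
    """Index of the largest entry (the first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)

win_target = make_target(4, 1)    # move on cell 4, game won
loss_target = make_target(4, -1)  # same move, game lost
print(win_target, argmax(win_target))    # [0, 0, 0, 0, 1, 0, 0, 0, 0] 4
print(loss_target, argmax(loss_target))  # [0, 0, 0, 0, -1, 0, 0, 0, 0] 0
```

If the accuracy being reported is argmax-based (as Keras's categorical accuracy is), the loss target no longer points at the move that was actually played, which may be one reason the accuracy numbers look meaningless.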

But currently I often get an accuracy of 0.000e-000.

Can I train on a "bad reward" at all? Or, if I can't do it with -1, how should I do it instead?

Thanks in advance.


Solution

  • You need to backpropagate the reward obtained at the end of the game. Have a look at this tutorial.

    In short, from this tutorial:

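    # At the end of the game, propagate the final reward backwards
    # and update the value estimate of every state visited in this game.
    def feedReward(self, reward):
        # Walk from the last state of the game back to the first one.
        for st in reversed(self.states):
            if self.states_value.get(st) is None:
                self.states_value[st] = 0  # unseen state: start at 0
            # Move the state's value towards the discounted value of what followed it.
            self.states_value[st] += self.lr * (self.decay_gamma * reward
                                                - self.states_value[st])
            # This state's updated value is the "reward" for the state before it.
            reward = self.states_value[st]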
    

    As you can see, the reward received at the last step (say step 5, the end of the game) is propagated backwards (not in the derivative sense) through all the preceding steps (4, 3, 2, 1) with a decay factor (`decay_gamma`). This is the way to go because tic-tac-toe is a game with a delayed reward, unlike many classic reinforcement learning environments, where there is usually a reward (positive or negative) at each step. Here the value of the action taken at step T depends on the final action at some later step T + k. This final action gives a reward of 1 if it ended the game with a win, or a reward of -1 if the opponent played the last action and won.
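To see the decay in numbers, here is a self-contained sketch that wraps the excerpt in a hypothetical `ValueTable` class; `lr=0.2`, `decay_gamma=0.9` and the placeholder state keys `"s1"` to `"s5"` are assumptions, not values taken from the tutorial:

```python
class ValueTable:
    """Hypothetical minimal container for the attributes feedReward relies on."""

    def __init__(self, lr=0.2, decay_gamma=0.9):
        self.lr = lr                    # learning rate (assumed value)
        self.decay_gamma = decay_gamma  # decay / discount factor (assumed value)
        self.states = []                # states visited in the current game, in order
        self.states_value = {}          # state -> estimated value

    def feedReward(self, reward):
        # Same update as the excerpt: walk backwards from the final state.
        for st in reversed(self.states):
            if self.states_value.get(st) is None:
                self.states_value[st] = 0
            self.states_value[st] += self.lr * (self.decay_gamma * reward
                                                - self.states_value[st])
            # The updated value becomes the target for the previous state.
            reward = self.states_value[st]


winner = ValueTable(lr=0.2, decay_gamma=0.9)
loser = ValueTable(lr=0.2, decay_gamma=0.9)
for s in ["s1", "s2", "s3", "s4", "s5"]:  # placeholder state keys
    winner.states.append(s)
    loser.states.append(s)
winner.feedReward(1)   # game won: values become positive
loser.feedReward(-1)   # game lost: values become negative
print({s: round(v, 4) for s, v in winner.states_value.items()})
```

After one won game the last state gets 0.2 × 0.9 × 1 = 0.18, the state before it 0.2 × 0.9 × 0.18 = 0.0324, and so on, while the lost game produces the same magnitudes with a negative sign. In this scheme a reward of -1 is therefore unproblematic: it simply pushes the values of the states that led to the loss below zero.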

    As for accuracy, it is not a meaningful metric in reinforcement learning, because there are no fixed correct labels to compare the outputs against. A better metric is the mean cumulative reward per game: with +1 for a win and -1 for a loss, it is 0 when the agent wins as often as it loses, > 0 if it wins more often, and < 0 otherwise.
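A minimal way to track that metric, assuming a reward of +1 per win, -1 per loss and (hypothetically) 0 per draw, is to keep the results of the most recent games in a fixed-size window; `RewardTracker` and the window size are illustrative choices, and the game results below are made up:

```python
from collections import deque


class RewardTracker:
    """Keeps the rewards of the most recent games and reports their mean."""

    def __init__(self, window=100):
        # Older results drop out automatically once `window` games are stored.
        self.recent = deque(maxlen=window)

    def add(self, episode_reward):
        self.recent.append(episode_reward)

    def mean(self):
        return sum(self.recent) / len(self.recent) if self.recent else 0.0


tracker = RewardTracker(window=4)
for r in [1, -1, 1, 1, -1]:  # made-up results: +1 win, -1 loss
    tracker.add(r)
print(tracker.mean())  # last 4 games: -1, 1, 1, -1 -> 0.0
```

A sliding window rather than an all-time average makes it easier to see whether the agent is still improving late in training.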