
Implementing SARSA using Gradient Descent


I have successfully implemented a SARSA algorithm (both one-step and with eligibility traces) using table lookup. In essence, I have a Q-value matrix where each row corresponds to a state and each column to an action.

Something like:

[Q(s1,a1), Q(s1,a2), Q(s1,a3), Q(s1,a4)]
[Q(s2,a1), Q(s2,a2), Q(s2,a3), Q(s2,a4)]
.
.
.
[Q(sn,a1), Q(sn,a2), Q(sn,a3), Q(sn,a4)]

At each time step, the row for the current state is selected, an action is chosen according to the policy, and the corresponding Q-value is updated according to the SARSA rule.
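A minimal sketch of the tabular one-step SARSA update described above, in plain Python. The sizes (6 states, 4 actions), the alpha, gamma and epsilon values, and the function names are illustrative assumptions, not taken from the question.

```python
import random

# Hypothetical sizes for illustration: 6 states, 4 actions.
N_STATES, N_ACTIONS = 6, 4

# Q-table: one row per state, one column per action (all start at 0).
Q = [[0.0] * N_ACTIONS for _ in range(N_STATES)]

def epsilon_greedy(q_row, epsilon, rng=random):
    """With probability epsilon pick a random action, otherwise the greedy one."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_row))
    return max(range(len(q_row)), key=lambda a: q_row[a])

def sarsa_update(Q, s, a, reward, s_next, a_next, alpha=0.1, gamma=0.9):
    """One-step SARSA: move Q[s][a] toward reward + gamma * Q[s_next][a_next]."""
    td_target = reward + gamma * Q[s_next][a_next]
    Q[s][a] += alpha * (td_target - Q[s][a])
```

Here `a_next` is the action the policy actually chooses in `s_next` (e.g. via `epsilon_greedy(Q[s_next], 0.1)`), which is what makes this SARSA rather than Q-learning.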

I am now trying to implement it as a neural network trained with gradient descent.

My first idea was to create a two-layer network, with as many input neurons as there are states and as many output neurons as there are actions. Each input would be fully connected to each output, so the weights would in fact form a matrix just like the one above.

My input vector would be a 1 x n row vector, where n is the number of input neurons (i.e. the number of states). All entries would be 0, except the one corresponding to the current state, which would be 1. For example:

[0 0 0 1 0 0]

Would be an input vector for an agent in state 4.
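A small sketch of this one-hot encoding; the function name `one_hot` is hypothetical, and states are numbered from 1 to match the text.

```python
def one_hot(state, n_states):
    """Return a 1 x n_states input vector with a 1 at the given (1-based) state."""
    x = [0] * n_states
    x[state - 1] = 1
    return x

print(one_hot(4, 6))  # [0, 0, 0, 1, 0, 0]
```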

So, the process would be something like:

[0 0 0 1 0 0] X [ 4 7 9 3]
                [ 5 3 2 9]
                [ 3 5 6 9]
                [ 9 3 2 6]
                [ 2 5 7 8]
                [ 8 2 3 5]

where the weight matrix holds random sample values.

The result would be:

[9 3 2 6]
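The multiplication above can be sketched in plain Python as a linear layer with no activation function. Multiplying by a one-hot vector simply selects one row of the weight matrix, which is why the network behaves like the lookup table. The names `forward` and `greedy_action` are illustrative.

```python
# Weight matrix from the example: one row per input (state), one column per output (action).
W = [
    [4, 7, 9, 3],
    [5, 3, 2, 9],
    [3, 5, 6, 9],
    [9, 3, 2, 6],
    [2, 5, 7, 8],
    [8, 2, 3, 5],
]

def forward(x, W):
    """Linear layer with no activation: output[j] = sum_i x[i] * W[i][j]."""
    n_out = len(W[0])
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(n_out)]

x = [0, 0, 0, 1, 0, 0]  # agent in state 4
q_values = forward(x, W)
print(q_values)  # [9, 3, 2, 6]

# Greedy action, reported 1-based as in the text.
greedy_action = max(range(len(q_values)), key=q_values.__getitem__) + 1
print(greedy_action)  # 1
```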

This means that, under a greedy policy, action 1 would be picked, and the weight w of the connection between the fourth input neuron and the first output neuron would be updated as follows:

w = w_old + learning_rate*(reward + discount*network_output - w_old)

(equation taken from the SARSA update rule, where network_output is the network's Q-value for the next state and the action chosen in it)
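For a linear network, the general gradient-descent form of SARSA (semi-gradient SARSA) updates every weight by learning_rate * TD error * dQ(s,a)/dw. A minimal sketch, assuming a linear output layer with no activation; the function names and the example transition (reward 1, next state 6, next action 1) are hypothetical. With a one-hot input, the gradient is 1 for the single weight W[s][a] and 0 elsewhere, so the update reduces to the per-weight rule above. Note that the weight can also decrease.

```python
def q_value(x, W, a):
    """Q(s, a) for a linear network: the weighted sum of the inputs into output unit a."""
    return sum(xi * W[i][a] for i, xi in enumerate(x))

def semi_gradient_sarsa_step(W, x, a, reward, x_next, a_next, alpha=0.1, gamma=0.9):
    """One semi-gradient SARSA step for a linear Q-network (no activation function).

    dQ(s, a)/dW[i][a] = x[i], and the gradient w.r.t. every other column is 0,
    so only the weights feeding output unit a change. With a one-hot x this
    touches exactly one weight, W[s][a].
    """
    td_error = reward + gamma * q_value(x_next, W, a_next) - q_value(x, W, a)
    for i, xi in enumerate(x):
        W[i][a] += alpha * td_error * xi

# Example weights from the question (as floats); rows = states, columns = actions.
W = [[4.0, 7.0, 9.0, 3.0],
     [5.0, 3.0, 2.0, 9.0],
     [3.0, 5.0, 6.0, 9.0],
     [9.0, 3.0, 2.0, 6.0],
     [2.0, 5.0, 7.0, 8.0],
     [8.0, 2.0, 3.0, 5.0]]

# Hypothetical transition: state 4, action 1, reward 1, then state 6, action 1.
x, x_next = [0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 1]
semi_gradient_sarsa_step(W, x, 0, 1.0, x_next, 0, alpha=0.5, gamma=0.9)
print(W[3][0])  # approximately 8.6: TD error is 1 + 0.9*8 - 9 = -0.8
```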

However, I am not convinced by this implementation. From what I have read, the network weights should be used to compute the Q-value of a state-action pair, but I'm not sure they should directly represent those values, especially since I have usually seen weights restricted to the range 0 to 1.

Any advice?


Solution

  • Summary: your current approach is correct, except that you shouldn't restrict your output values to be between 0 and 1.

    This page has a great explanation, which I will summarize here. It doesn't specifically discuss SARSA, but I think everything it says should translate.

    The values in the output vector should indeed represent your network's estimates of the Q-values for each action in the current state. For this reason, it's typically recommended that you not restrict the output range to between zero and one: use a linear output layer that just sums the inputs multiplied by their connection weights, rather than applying a sigmoid activation function.
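A small sketch of why a sigmoid output is a poor fit for Q-values: both units below are trained by gradient descent toward a hypothetical target Q-value of 5, with a single input fixed at 1 (as for a one-hot state). The target, learning rate and iteration count are assumptions chosen for illustration.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

target = 5.0   # hypothetical Q-value larger than 1
alpha = 0.1
w_lin = 0.0    # weight of a linear output unit
w_sig = 0.0    # weight of a sigmoid output unit with the same input

for _ in range(1000):
    # Gradient descent on 0.5 * (target - output)**2 for each unit.
    w_lin += alpha * (target - w_lin)
    y = sigmoid(w_sig)
    w_sig += alpha * (target - y) * y * (1.0 - y)

print(round(w_lin, 3))        # 5.0
print(sigmoid(w_sig) < 1.0)   # True: the sigmoid unit can never reach 5
```

The linear unit converges to the target, while the sigmoid unit stays below 1 however long it is trained.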

    As for how to represent the states, one option is to represent them in terms of sensors that the agent has or might theoretically have. In the example below, for instance, the robot has three "feeler" sensors, each of which can be in one of three conditions. Together, they provide the robot with all of the information it's going to get about which state it's in.


    However, if you want to give your agent perfect information, you can imagine that it has a sensor that tells it exactly which state it is in, as shown near the end of this page. This would work exactly the way that your network is currently set up, with one input representing each state.