
Solving GridWorld using Q-Learning and function approximation


I'm studying the simple 3x4 GridWorld problem described in Russell & Norvig, Ch. 21.2. I've solved it using Q-learning and a Q-table, and now I'd like to replace the table with a function approximator.

I'm using MATLAB and have tried both neural networks and decision trees, but I'm not getting the expected results: the learned policy is bad. I've read some papers on the topic, but most of them are theoretical and say little about the actual implementation.

I've been using offline learning because it's simpler. My approach goes like this:

  1. Initialize a decision tree (or NN) with 16 binary input units: one for each of the 12 grid positions plus one for each of the 4 possible actions (up, down, left, right).
  2. Run many iterations, saving the state-action vector (qstate) and the computed Q-value of each one in a training set.
  3. Train the decision tree (or NN) using the training set.
  4. Erase the training set and repeat from step 2, using the just-trained decision tree (or NN) to calculate Q-values.
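A minimal Python sketch of the 16-unit input from step 1. It assumes row-major numbering of the 12 cells (row 0 at the top) and the action order up, down, left, right; `state_to_pair` is a hypothetical counterpart of the `statetopair` helper used in the code below, whose actual implementation isn't shown.

```python
ROWS, COLS = 3, 4
ACTIONS = ("up", "down", "left", "right")


def state_to_pair(row, col, action):
    """Return the 16 binary inputs for a state-action pair.

    Units 0-11: one-hot grid position in row-major order (row * COLS + col).
    Units 12-15: one-hot action, indexed as in ACTIONS.
    """
    if not (0 <= row < ROWS and 0 <= col < COLS):
        raise ValueError(f"position ({row}, {col}) is outside the grid")
    if not 0 <= action < len(ACTIONS):
        raise ValueError(f"unknown action index {action}")
    units = [0] * (ROWS * COLS + len(ACTIONS))
    units[row * COLS + col] = 1          # which cell the agent is in
    units[ROWS * COLS + action] = 1      # which action is taken
    return units


print(state_to_pair(2, 0, ACTIONS.index("right")))
# [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1]
```

One consequence of this encoding: a model that is linear in these 16 inputs can only represent Q(s, a) = f(s) + g(a), which ranks the actions the same way in every state, so a neural network needs at least one hidden layer to express a useful policy. A regression tree can capture the interaction by splitting on position units and action units in turn.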

It seems too simple to be true, and indeed I don't get the expected results. Here's some MATLAB code:

retrain = 1;
if (retrain)
    x = zeros(1, 16); % Training set inputs (state-action vectors)
    y = 0;            % Training set targets (Q-values)
    t = 0;            % Iteration counter
end
tree = fitrtree(x, y);
x = zeros(1, 16);
y = 0;
for i = 1:100
    % Get the initial game state as a 3x4 matrix
    gamestate = initialstate();
    gameover = 0; % 'end' is a reserved keyword in MATLAB, so use another name
    while (gameover == 0)
        t = t + 1; % Increase the iteration

        % Get the index of the best action to take
        index = chooseaction(gamestate, tree);

        % Make the action and get the new game state and reward
        [newgamestate, reward] = makeaction(gamestate, index);

        % Get the state-action vector for the current game state and chosen action
        sa_pair = statetopair(gamestate, index);

        % Check for end of game
        if (isfinalstate(gamestate))
            gameover = 1;
            % Get the final reward
            reward = finalreward(gamestate);
        end

        % Add a sample to the training set
        x(size(x, 1) + 1, :) = sa_pair;
        y(size(y, 1) + 1, 1) = updateq(reward, gamestate, index, newgamestate, tree, t, gameover);

        % Update gamestate
        gamestate = newgamestate;
    end
end

chooseaction picks a random action half the time and otherwise the action with the highest predicted Q-value. The updateq function is:

function [ q ] = updateq( reward, gamestate, index, newgamestate, tree, iteration, finalstate )

alfa = 1/iteration;
gamma = 0.99;

%Get the action with maximum qvalue in the new state s'
amax = chooseaction(newgamestate, tree);

%Get the corresponding state-action vectors
newsa_pair = statetopair(newgamestate, amax);    
sa_pair = statetopair(gamestate, index);

if(finalstate == 0)
    X = reward + gamma * predict(tree, newsa_pair);
else
    X = reward;
end

q = (1 - alfa) * predict(tree, sa_pair) + alfa * X;    

end
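A Python sketch of the same target computation, assuming `predict` is any callable that maps a feature vector to a Q-value (standing in for `predict(tree, ...)`) and `encode` plays the role of `statetopair`; both are hypothetical names. It keeps the `1/iteration` blend from `updateq`, but picks a' with an explicit greedy argmax. If `chooseaction` is the same routine that explores half the time, the action it returns for s' is random half the time, whereas the Q-learning target needs the maximum of Q(s', a') over a'.

```python
import random


def greedy_action(state, predict, encode, n_actions=4):
    """argmax over a of Q(state, a) under the current approximator."""
    return max(range(n_actions), key=lambda a: predict(encode(state, a)))


def epsilon_greedy(state, predict, encode, rng, epsilon=0.5, n_actions=4):
    """Behaviour policy: a random action with probability epsilon, else greedy."""
    if rng.random() < epsilon:
        return rng.randrange(n_actions)
    return greedy_action(state, predict, encode, n_actions)


def update_q(reward, state, action, new_state, predict, encode,
             iteration, final, gamma=0.99):
    """Training target for (state, action), mirroring updateq."""
    alpha = 1.0 / iteration
    if final:
        target = reward
    else:
        # Bootstrap from the greedy action in s', never an exploratory one.
        a_max = greedy_action(new_state, predict, encode)
        target = reward + gamma * predict(encode(new_state, a_max))
    return (1 - alpha) * predict(encode(state, action)) + alpha * target


# Usage with a toy lookup table in place of a trained tree.
table = {("s", 0): 1.0, ("s2", 1): 2.0}
predict = lambda x: table.get(x, 0.0)
encode = lambda s, a: (s, a)
print(update_q(0.5, "s", 0, "s2", predict, encode, iteration=2, final=False))  # ≈ 1.74
```

Note that with `iteration` counting every step across all games, as `t` does in the main loop, alpha is already 0.01 after 100 steps, so each target stays close to the tree's previous prediction.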

Any suggestion would be greatly appreciated!


Solution

  • The issue was that in offline Q-learning you need to repeat the data-gathering process at least n times, where n depends on the problem you're trying to model. If you inspect the Q-values computed in each iteration, the reason becomes clear.

    In the first iteration (one round of data gathering and training) you learn only the final states; in the second, you also learn the penultimate states; in the third, the antepenultimate states; and so on. The Q-values propagate backwards, from the final state to the initial state. In the GridWorld example, the shortest game visits 6 states, so at least 6 iterations are needed before the initial state's Q-values reflect the final reward.

    The corrected algorithm is therefore:

    1. Initialize a decision tree (or NN) with 16 input binary units - one for each position in the grid plus the 4 possible actions (up, down, left, right).
    2. Play many games (for this GridWorld example, 30 games are enough), saving the state-action vector and the computed Q-value of every step in a training set.
    3. Train the decision tree (or NN) using the training set.
    4. Erase the training set.
    5. Repeat from step 2, using the just trained decision tree (or NN) to calculate qvalues, at least n times, where n depends on your problem. For this GridWorld example n is 6, but you'll get better results for all states if you repeat the process 7-8 times.
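A minimal Python sketch of steps 1-5, under assumptions that are not in the original: the Russell & Norvig layout (wall at (2,2), +1 at (4,3), -1 at (4,2), start at (1,1), step reward -0.04) stored with row 0 at the top, deterministic moves instead of the book's noisy ones, and a hypothetical `TableRegressor` that averages the targets seen for each input in place of `fitrtree`. It also uses the plain Q-learning target r + gamma * max Q(s', a') (alpha = 1) rather than the 1/t blend in updateq.

```python
import random
from collections import defaultdict

# Hypothetical 3x4 GridWorld, row 0 at the top; moves are deterministic here.
ROWS, COLS = 3, 4
WALL = (1, 1)
TERMINALS = {(0, 3): 1.0, (1, 3): -1.0}
START = (2, 0)
MOVES = [(-1, 0), (1, 0), (0, -1), (0, 1)]  # up, down, left, right
STEP_REWARD, GAMMA, EPSILON = -0.04, 0.99, 0.5


def step(state, action):
    """Apply a move; bumping into the wall or the grid edge leaves the agent in place."""
    r, c = state[0] + MOVES[action][0], state[1] + MOVES[action][1]
    if not (0 <= r < ROWS and 0 <= c < COLS) or (r, c) == WALL:
        return state
    return (r, c)


def encode(state, action):
    """16 binary inputs: one-hot grid position (12) + one-hot action (4)."""
    v = [0] * (ROWS * COLS + len(MOVES))
    v[state[0] * COLS + state[1]] = 1
    v[ROWS * COLS + action] = 1
    return tuple(v)


class TableRegressor:
    """Hypothetical stand-in for fitrtree: averages the targets seen for each input."""

    def __init__(self, xs=(), ys=()):
        sums, counts = defaultdict(float), defaultdict(int)
        for x, y in zip(xs, ys):
            sums[x] += y
            counts[x] += 1
        self.table = {x: sums[x] / counts[x] for x in sums}

    def predict(self, x):
        return self.table.get(x, 0.0)  # inputs not seen this round predict 0


def greedy(model, state):
    return max(range(len(MOVES)), key=lambda a: model.predict(encode(state, a)))


def play_games(model, rng, games, max_steps=100):
    """Step 2: play games with the current model, recording (input, target) pairs."""
    xs, ys = [], []
    for _ in range(games):
        state = START
        for _ in range(max_steps):
            explore = rng.random() < EPSILON
            action = rng.randrange(len(MOVES)) if explore else greedy(model, state)
            if state in TERMINALS:  # final state: the target is the final reward
                xs.append(encode(state, action))
                ys.append(TERMINALS[state])
                break
            nxt = step(state, action)
            best_next = model.predict(encode(nxt, greedy(model, nxt)))
            xs.append(encode(state, action))
            ys.append(STEP_REWARD + GAMMA * best_next)
            state = nxt
    return xs, ys


def train(rounds=8, games=30, seed=0):
    """Steps 2-5: gather a fresh training set, refit, discard it, repeat."""
    rng = random.Random(seed)
    model = TableRegressor()
    for _ in range(rounds):
        xs, ys = play_games(model, rng, games)
        model = TableRegressor(xs, ys)  # the previous training set is not reused
    return model


model = train()
print("greedy action at start:", greedy(model, START))
```

Because the model is rebuilt from each round's data only, a state-action pair that is not visited in a round falls back to 0, which is one reason to play enough games per round.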
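To check why n is 6 for this grid, the sketch below counts how many gather-and-train rounds pass before the +1 reward shows up in the start state's Q-values. It simplifies heavily, and each simplification is an assumption: one sweep over every state-action pair replaces the games, a lookup table acts as a perfect regressor, moves are deterministic, the step reward is 0 (so a positive Q-value can only come from the +1 terminal), and alpha = 1. The layout is the same Russell & Norvig grid, stored with row 0 at the top.

```python
# Hypothetical 3x4 GridWorld, row 0 at the top; (2, 0) is the start cell.
ROWS, COLS = 3, 4
WALL = (1, 1)
TERMINALS = {(0, 3): 1.0, (1, 3): -1.0}
START = (2, 0)
MOVES = [(-1, 0), (1, 0), (0, -1), (0, 1)]  # up, down, left, right
GAMMA = 0.99
STATES = [(r, c) for r in range(ROWS) for c in range(COLS) if (r, c) != WALL]


def step(state, action):
    """Apply a move; bumping into the wall or the grid edge leaves the agent in place."""
    r, c = state[0] + MOVES[action][0], state[1] + MOVES[action][1]
    if not (0 <= r < ROWS and 0 <= c < COLS) or (r, c) == WALL:
        return state
    return (r, c)


def sweep_round(q_prev):
    """One round: new targets for every pair, computed only from the previous model."""
    q_new = {}
    for s in STATES:
        for a in range(len(MOVES)):
            if s in TERMINALS:
                q_new[(s, a)] = TERMINALS[s]  # final states learn their reward directly
            else:
                s2 = step(s, a)  # step reward is 0, so only the discounted max remains
                q_new[(s, a)] = GAMMA * max(q_prev.get((s2, b), 0.0) for b in range(len(MOVES)))
    return q_new


def rounds_until_goal_reaches(state, max_rounds=20):
    """First round in which max over a of Q(state, a) becomes positive."""
    q = {}
    for k in range(1, max_rounds + 1):
        q = sweep_round(q)
        if max(q[(state, a)] for a in range(len(MOVES))) > 0:
            return k, q
    return None, q


k, q = rounds_until_goal_reaches(START)
print(k)  # 6
```

The start cell is 5 moves from the +1 terminal, and each round pushes the reward back by one move, so the terminal itself is learned in round 1 and the start cell in round 6, which matches the 6 visited states of the shortest game.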