Tags: r, reinforcement-learning, sample

How can I fix the error in the Q-Learning algorithm in R?


I am trying to implement the Q-Learning algorithm in R:


# Define the map
map <- matrix(c(0, 1, 1, 0, 0, 0, 0, 1), nrow = 2, ncol = 4, byrow = TRUE)
# State labels
rownames(map) <- c("Start", "End")
# Action labels
colnames(map) <- c("Up", "Down", "Left", "Right")
# Rewards for each state-action pair
rewards <- matrix(c(-1, -1, -1, -1, -1, -1, -1, 10), nrow = 2, ncol = 4, byrow = TRUE)

# Q-Learning Algorithm
q_learning <- function(P, R, gamma = 0.9, alpha = 0.1, epsilon = 0.1, max_iter = 1000) {
  # Initialize the Q-value function
  Q <- matrix(rep(0, nrow(P) * ncol(P)), nrow = nrow(P), ncol = ncol(P))
  # Initialize the state
  state <- sample(1:nrow(P), 1)
  # Iterate until convergence or maximum iterations reached
  for (i in 1:max_iter) {
    # Choose an action using epsilon-greedy policy
    if (runif(1) < epsilon) {
      action <- sample(1:ncol(P), 1)
    } else {
      action <- which.max(Q[state, ])
    }
    # Observe the next state and reward
    prob <- P[state, action]
    next_state <- sample(1:nrow(P), 1, prob = prob)
    reward <- R[state, action]
    # Update the Q-value function
    Q[state, action] <- Q[state, action] + alpha * (reward + gamma * max(Q[next_state, ]) - Q[state, action])
    # Update the state
    state <- next_state
  }
  # Derive the optimal policy (argmax in R using the which.max)
  policy <- apply(Q, 1, which.max)
  # Return the Q-value function and policy
  return(list(Q = Q, policy = policy))
}


# Run the Q-Learning Algorithm on the map
q_learning(P = map, R = rewards, gamma = 0.9, alpha = 0.1, epsilon = 0.1, max_iter = 1000)


The sample() call fails with an "incorrect number of probabilities" error:

Error in sample.int(length(x), size, replace, prob) :
incorrect number of probabilities

How can I fix it?
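The error can be reproduced outside the loop. In the code above, `P[state, action]` is a single entry of `map`, while `sample(1:nrow(P), 1, prob = prob)` needs one probability for each of the `nrow(P)` candidate states. A minimal check, using the `map` matrix from the question:

```r
# The map matrix from the question (2 states x 4 actions)
map <- matrix(c(0, 1, 1, 0, 0, 0, 0, 1), nrow = 2, ncol = 4, byrow = TRUE)

prob <- map[1, 2]   # what P[state, action] returns: a single number
length(prob)        # 1, while 1:nrow(map) has 2 elements

# sample() requires length(prob) to equal the number of candidates,
# so this call stops with the same error as in the question
msg <- tryCatch(sample(1:nrow(map), 1, prob = prob),
                error = function(e) conditionMessage(e))
print(msg)
```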


Solution

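  • I'm not familiar with the algorithm, but from looking at the code, the problem is that P[state, action] is a single number, while sample(1:nrow(P), 1, prob = prob) needs one probability for each of the nrow(P) possible next states. At a guess, you could try

    prob <- P[ , action]

    which creates a vector of length nrow(P). You will need to work through the logic for yourself!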
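A caveat on the suggestion above: `P[ , action]` gives `prob` the right length, but the draw then ignores the current state, and with the question's `map` the "Up" column is `c(0, 0)`, so `sample()` would stop with "too few positive probabilities" instead. Q-learning normally assumes the next-state distribution depends on both the current state and the action. The sketch below is one way to express that, assuming transitions are stored in a 3-D array `P[s, a, s']`. The array and its dynamics ("Right" moves to End, "Left" moves to Start, "Up" and "Down" stay put) are hypothetical and only for illustration, since the question does not say how `map` should be read as transitions; the reward matrix is the one from the question.

```r
# ASSUMPTION: P[s, a, s'] = probability of moving from s to s' after action a,
# so P[state, action, ] is a full distribution over next states.
states  <- c("Start", "End")
actions <- c("Up", "Down", "Left", "Right")

# Hypothetical dynamics, for illustration only
P <- array(0, dim = c(2, 4, 2), dimnames = list(states, actions, states))
for (s in states) {
  P[s, "Up", s]         <- 1  # stay in the same state
  P[s, "Down", s]       <- 1  # stay in the same state
  P[s, "Left", "Start"] <- 1  # go to Start
  P[s, "Right", "End"]  <- 1  # go to End
}

# Reward matrix from the question: rewards[s, a]
rewards <- matrix(c(-1, -1, -1, -1, -1, -1, -1, 10), nrow = 2, ncol = 4,
                  byrow = TRUE, dimnames = list(states, actions))

q_learning <- function(P, R, gamma = 0.9, alpha = 0.1, epsilon = 0.1,
                       max_iter = 1000) {
  n_states  <- dim(P)[1]
  n_actions <- dim(P)[2]
  # Every P[s, a, ] must sum to 1 to be a valid distribution
  stopifnot(all(abs(apply(P, c(1, 2), sum) - 1) < 1e-8))

  Q <- matrix(0, nrow = n_states, ncol = n_actions,
              dimnames = dimnames(P)[1:2])
  state <- sample.int(n_states, 1)
  for (i in seq_len(max_iter)) {
    # Epsilon-greedy action selection
    if (runif(1) < epsilon) {
      action <- sample.int(n_actions, 1)
    } else {
      action <- which.max(Q[state, ])
    }
    # prob now has exactly one entry per possible next state
    prob <- P[state, action, ]
    next_state <- sample.int(n_states, 1, prob = prob)
    reward <- R[state, action]
    # Q-learning update: bootstrap from the best action in the next state
    td_target <- reward + gamma * max(Q[next_state, ])
    Q[state, action] <- Q[state, action] + alpha * (td_target - Q[state, action])
    state <- next_state
  }
  # Greedy policy: index of the highest-valued action in each state
  policy <- apply(Q, 1, which.max)
  list(Q = Q, policy = policy)
}

set.seed(1)
result <- q_learning(P, rewards)
print(result$Q)
print(actions[result$policy])
```

Storing transitions as a 3-D array keeps `P[state, action, ]` a vector of length `dim(P)[1]`, so `sample.int()` always receives one probability per candidate state, and the validity check at the top catches rows that are all zeros before the loop starts. Like the original, this runs as a single continuing trajectory with no episode reset.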