Search code examples
Tags: r, simulation, reinforcement-learning

n-armed bandit simulation in R


I'm using Sutton & Barto's ebook *Reinforcement Learning: An Introduction* to study reinforcement learning. I'm having trouble reproducing the results (plots) shown on the action-value methods page.

More specifically, how can I simulate the greedy value for each task? The book says:

...we can plot the performance and behavior of various methods as they improve with experience over 1000 plays...

So I guess I have to keep track of the estimated action values as better ones are found. My problem is the greedy approach: since it makes no exploratory moves, how do I determine what the greedy behavior is?

Thanks for all the comments and answers!

UPDATE: See the code in my answer below.


Solution

  • I finally got this right. As pointed out in the book, the eps player should beat the greedy player because of its exploratory moves. The code is slow and needs some optimization, but here it is:

    (Image of the resulting plots not available.)

    # Build one randomly generated n-armed bandit task (the "testbed").
    #
    # Each arm's true mean reward is drawn from N(u, sdev.arm^2). The reward an
    # arm would pay on each play is then drawn from N(arm mean, sdev.rewards^2).
    #
    # Args:
    #   arms:         number of arms (columns of `rewards`).
    #   plays:        number of plays (rows of `rewards`).
    #   u:            mean of the distribution the true arm means are drawn from.
    #   sdev.arm:     sd of the distribution the true arm means are drawn from.
    #   sdev.rewards: sd of each individual reward around its arm's mean.
    #
    # Returns a list with
    #   optimal: numeric vector (length `arms`) of true arm means;
    #   rewards: numeric `plays` x `arms` matrix; rewards[i, a] is what arm `a`
    #            pays if it is pulled on play `i`.
    get.testbed <- function(arms = 10, plays = 500, u = 0, sdev.arm = 1, sdev.rewards = 1){
      optimal <- rnorm(arms, u, sdev.arm)

      # Draw arm by arm (all of arm 1's plays, then arm 2's, ...) and fill the
      # matrix column-wise. Building the matrix explicitly keeps the result
      # plays x arms even when plays == 1 or arms == 1, where sapply() would
      # simplify to a plain vector and break rewards[i, idx] indexing.
      draws <- rnorm(plays * arms, mean = rep(optimal, each = plays), sd = sdev.rewards)
      rewards <- matrix(draws, nrow = plays, ncol = arms)

      list(optimal = optimal, rewards = rewards)
    }
    
    # Play one epsilon-greedy agent on a freshly generated bandit task.
    #
    # On each play the agent explores (pulls a uniformly random arm) with
    # probability `eps`. Otherwise it exploits: it pulls an arm with the highest
    # current value estimate, breaking ties at random. Estimates start at 0 and
    # are incremental sample averages of the rewards each arm has paid so far.
    #
    # Args: arms, plays, u, sdev.arm, sdev.rewards are passed to get.testbed();
    #   eps: exploration probability; eps = 0 gives the pure greedy player.
    #
    # Returns a list with
    #   slot.rewards: final value estimate of each arm;
    #   reward.hist:  reward received on each play;
    #   optimal.hist: 1 on plays where the arm with the highest true mean was
    #                 pulled, 0 otherwise;
    #   pulls:        number of times each arm was pulled.
    play.slots <- function(arms = 10, plays = 500, u = 0, sdev.arm = 1, sdev.rewards = 1, eps = 0.1){

      testbed <- get.testbed(arms, plays, u, sdev.arm, sdev.rewards)
      optimal <- testbed$optimal
      rewards <- testbed$rewards

      optim.index <- which.max(optimal)
      slot.rewards <- rep(0, arms)   # current value estimates
      reward.hist <- rep(0, plays)
      optimal.hist <- rep(0, plays)
      pulls <- rep(0, arms)
      probs <- runif(plays)          # pre-drawn explore/exploit coin flips

      # Each choice depends on the estimates updated by earlier plays, so this
      # loop is inherently sequential.
      for (i in seq_len(plays)){

        # A plain if/else is the idiomatic choice for a scalar condition.
        if (probs[i] < eps) {
          idx <- sample(arms, 1)
        } else {
          # which.max() alone returns the FIRST maximum. The greedy player would
          # then start on arm 1 in every run (all estimates are 0) and favour
          # low-numbered arms on every tie, so break ties randomly instead.
          best <- which(slot.rewards == max(slot.rewards))
          # sample(best, 1) with a single value `best` would sample from
          # 1:best, so handle the single-winner case explicitly.
          idx <- if (length(best) == 1) best else sample(best, 1)
        }

        reward <- rewards[i, idx]
        reward.hist[i] <- reward

        if (idx == optim.index)
          optimal.hist[i] <- 1

        # Incremental sample-average update: Q <- Q + (R - Q) / n.
        pulls[idx] <- pulls[idx] + 1
        slot.rewards[idx] <- slot.rewards[idx] + (reward - slot.rewards[idx]) / pulls[idx]
      }

      list(slot.rewards = slot.rewards, reward.hist = reward.hist, optimal.hist = optimal.hist, pulls = pulls)
    }
    
    # Run `N` independent bandit tasks for every exploration rate in `eps`,
    # average the results over tasks, and plot two panels: the average reward
    # per play, and the running fraction of plays on which the optimal arm was
    # chosen.
    #
    # Args:
    #   N:   number of independent tasks averaged per player.
    #   eps: vector of exploration probabilities, one player (curve) per value.
    #   arms, plays, u, sdev.arm, sdev.rewards are forwarded to play.slots().
    #
    # Returns (invisibly) a list with the plotted matrices:
    #   rewards.hist: plays x length(eps) average reward per play;
    #   optim.hist:   plays x length(eps) cumulative fraction of optimal pulls.
    do.simulation <- function(N = 100, arms = 10, plays = 500, u = 0, sdev.arm = 1, sdev.rewards = 1, eps = c(0.0, 0.01, 0.1)){
      stopifnot(N >= 1, plays >= 1, length(eps) >= 1)

      n.players <- length(eps)
      col.names <- paste('eps', eps)
      rewards.hist <- matrix(0, nrow = plays, ncol = n.players, dimnames = list(NULL, col.names))
      optim.hist <- rewards.hist

      for (p in seq_len(n.players)){
        for (i in seq_len(N)){
          play.results <- play.slots(arms, plays, u, sdev.arm, sdev.rewards, eps[p])
          rewards.hist[, p] <- rewards.hist[, p] + play.results$reward.hist
          optim.hist[, p] <- optim.hist[, p] + play.results$optimal.hist
        }
      }

      rewards.hist <- rewards.hist / N
      optim.hist <- optim.hist / N
      # Running average over plays; unlike the reward panel, this curve is
      # cumulative. Assigning into optim.hist[] keeps the matrix shape and the
      # column names even when plays == 1 (apply() then returns a plain vector).
      optim.hist[] <- apply(optim.hist, 2, cumsum) / seq_len(plays)

      ### Plot helper ###
      # Draw every column of `x` as one line on a single panel, with a legend.
      plot.result <- function(x, colors, leg.names, xlab, ylab, lwd = 2){
        rng <- range(x)
        # Extend the axis upwards only. This leaves room for the top-left legend
        # without clipping the lowest values (doubling a positive minimum would).
        ylim <- c(rng[1], rng[2] + diff(rng))
        plot.ts(x[, 1], ylim = ylim, col = colors[1], xlab = xlab, ylab = ylab, lwd = lwd)
        grid(col = 'lightgray')
        # Only line-drawing parameters go to lines(); passing xlab/ylab there
        # triggers "not a graphical parameter" warnings.
        for (i in seq_len(ncol(x))[-1])
          lines(x[, i], col = colors[i], lwd = lwd)
        legend('topleft', leg.names, col = colors, lwd = lwd, cex = 0.6, box.lwd = NA)
      }
      ### Plot helper ###

      #### Plots ####
      # Set2 is a qualitative ColorBrewer palette with at most 8 colours, and
      # brewer.pal() needs n >= 3. Fall back to base colours when RColorBrewer
      # is not installed or more curves are requested.
      colors <- if (n.players <= 8 && requireNamespace('RColorBrewer', quietly = TRUE)) {
        RColorBrewer::brewer.pal(max(3, n.players), 'Set2')[seq_len(n.players)]
      } else {
        grDevices::rainbow(n.players)
      }

      # Restore the caller's layout even if plotting fails.
      op <- par(mfrow = c(2, 1))
      on.exit(par(op), add = TRUE)

      plot.result(rewards.hist, colors, col.names, xlab = 'Plays', ylab = 'Average reward')
      plot.result(optim.hist, colors, col.names, xlab = 'Plays', ylab = 'Optimal move %')
      #### Plots ####

      invisible(list(rewards.hist = rewards.hist, optim.hist = optim.hist))
    }
    

    To run it, call:

    # Compare the pure greedy player (eps = 0) with two epsilon-greedy players,
    # each averaged over 100 random 10-armed tasks.
    do.simulation(eps = c(0, 0.01, 0.1), N = 100, arms = 10)