Search code examples
Tags: r, symbols, reinforcement-learning

R: matrix with directional arrows


I am trying to reproduce in R an algorithm described in Sutton and Barto (2018), but I have not been able to produce a matrix of arrows like the one the authors show on page 65:

[Image: the optimal-policy grid of arrows, $\pi_\ast$, from Sutton and Barto (2018), p. 65]

I tried to use the package "fields" for this purpose, but without much success.

In Python, the solution by Shangtong Zhang and Kenta Shimada uses arrow symbols, `ACTIONS_FIGS = ['←', '↑', '→', '↓']`, but this does not work well in R.

EDIT: I coded the initial actions and the action updates numerically as follows:

library(data.table)

# Action table for the 25 cells: every action is available (1) and each
# is chosen with probability 1/4.
# The table is built in a single data.table() call. Adding columns one by
# one with `DT$col <- value` copies the table on every assignment. Length-1
# columns are recycled to the 25 rows. Action columns are ordered left,
# down, right, up.
action_random <- data.table(
  cell         = 1:25,
  action_left  = 1,
  action_down  = 1,
  action_right = 1,
  action_up    = 1,
  proba        = 1 / 4
)
action_random

I was also able to adapt the code posted here to draw a simple grid of single-headed arrows:

# Four single-headed Unicode arrows in a 2 x 2 matrix. The matrix is
# filled column-wise, so the top row reads "left, right" and the bottom
# row "up, down".
arrows <- matrix(c("\U2190", "\U2191", "\U2192", "\U2193"), nrow = 2, ncol = 2)

# One data-frame row per matrix cell: x is the column index, y the row
# index. The symbol is looked up with a (row, column) index matrix.
grid_arrows <- expand.grid(
  x = seq_len(ncol(arrows)),
  y = seq_len(nrow(arrows))
)
grid_arrows$val <- arrows[cbind(grid_arrows$y, grid_arrows$x)]

library(ggplot2)

# Outline each cell and print its arrow. The reversed y scale puts matrix
# row 1 at the top; all axis decoration is removed.
ggplot(grid_arrows, aes(x = x, y = y, label = val)) +
  geom_tile(colour = "black", fill = "transparent") +
  geom_text(size = 14) +
  scale_y_reverse() +
  theme_classic() +
  theme(
    axis.line  = element_blank(),
    axis.text  = element_blank(),
    axis.ticks = element_blank(),
    axis.title = element_blank(),
    panel.grid = element_blank()
  )

However:
(i) there are no Unicode characters for the nice two- and four-directional arrows shown in table $\pi_\ast$ above;
(ii) as a result, I have not yet tried to code the mapping (bijection) between the numerical values in the table `action_random` and a table of arrows.

Any hints for resolving issues (i) and (ii) are welcome.


Solution

  • Here is a way to reproduce the matrix with the grid and lattice packages:

    # grid does all the drawing; lattice is used here only for panel.grid().
    library(grid)
    library(lattice)
    
    # Start a fresh page and draw inside a centred viewport that covers 80%
    # of the page in each direction.
    grid.newpage()
    pushViewport(viewport(height = 0.8, width = 0.8))
    # Outer border, then 4 vertical and 4 horizontal reference lines.
    grid.rect(height = 1, width = 1)
    panel.grid(v = 4, h = 4)
    
    # Draw the policy arrow(s) for one grid cell with grid::grid.curve().
    #
    # xCenter, yCenter: centre of the cell, in grid.curve()'s default units
    #   (npc of the current viewport).
    # type: direction code. A factor is matched by its label, not its
    #   integer code.
    #   'n'   arrow pointing up         's'  arrow pointing down
    #   'e'   arrow pointing right      'w'  arrow pointing left
    #   'ne'  two-headed curve from (x-d, y+d) to (x+d, y-d), meant to
    #         show "up or right"
    #   'nw'  two-headed curve from (x-d, y-d) to (x+d, y+d), meant to
    #         show "up or left"
    #   'all' a two-headed horizontal arrow plus a two-headed vertical arrow
    # d: half-length of each arrow, in the same units as the centre
    #   (default 0.05).
    # Value: whatever the last grid.curve() call returns. An unrecognised
    # `type` is an error; previously the cell was silently left empty.
    direct <- function(xCenter, yCenter, type, d = 0.05) {
      # switch() on a factor uses the integer code, which would select the
      # wrong arm. Always match on the label instead.
      type <- as.character(type)
      cx <- xCenter
      cy <- yCenter
    
      # All arrows share one look. Only the end points and the end(s)
      # carrying a closed head differ.
      draw <- function(x1, y1, x2, y2, ends) {
        grid.curve(x1, y1, x2, y2,
                   ncp = 1, angle = 90, gp = gpar(lwd = 1, fill = "black"),
                   inflect = FALSE, shape = 0,
                   arrow = arrow(type = "closed", ends = ends,
                                 angle = 30, length = unit(0.2, "cm")))
      }
    
      switch(type,
             n   = draw(cx, cy - d, cx, cy + d, "last"),
             s   = draw(cx, cy + d, cx, cy - d, "last"),
             # Drawn right-to-left with the head on the first (right) end.
             e   = draw(cx + d, cy, cx - d, cy, "first"),
             w   = draw(cx + d, cy, cx - d, cy, "last"),
             ne  = draw(cx - d, cy + d, cx + d, cy - d, "both"),
             nw  = draw(cx - d, cy - d, cx + d, cy + d, "both"),
             all = {
               draw(cx + d, cy, cx - d, cy, "both")
               draw(cx, cy - d, cx, cy + d, "both")
             },
             stop("unknown direction type: '", type, "'", call. = FALSE))
    }
    
    # Cell centres of the 5 x 5 grid, in npc units of the current viewport
    # (grid's npc origin is the bottom-left corner).
    x <- seq(0.1, 0.9, by = 0.2)
    y <- x
    # expand.grid() varies x0 fastest. The rows of `centers` therefore run
    # left to right along the bottom row (y0 = 0.1) first, then move upward.
    centers <- expand.grid(x0 = x, y0 = y)
    
    # Direction codes per grid row, from the bottom row (row1) to the top
    # row (row5), each listed left to right. This matches the row order of
    # `centers`.
    row1 <- row2 <- row3 <- c("ne", "n", rep("nw", 3))
    row4 <- c("ne", "n", "nw", "w", "w")
    row5 <- c("e", "all", "w", "all", "w")
    
    dir <- c(row1, row2, row3, row4, row5)
    # data.frame() would silently recycle a shorter `dir` across the cells.
    stopifnot(length(dir) == nrow(centers))
    # Keep `dir` as character. Before R 4.0.0, data.frame() made it a factor,
    # and switch() inside direct() would then treat it as an integer index.
    df <- data.frame(centers, dir, stringsAsFactors = FALSE)
    
    for (k in seq_len(nrow(df))) direct(df$x0[k], df$y0[k], df$dir[k])
    # Caption pi_* placed just below the grid.
    grid.text(bquote(~pi["*"]), y = -0.05)
    

    [Image: the 5 x 5 grid of policy arrows drawn by the code above]