Search code examples
rjags

JAGS: Use cell matrix as data input for excluding NA


Below are my model and the example data I use. There are some NAs in the data. I had to set priors on them so that JAGS would generate values for them, but this approach may cause errors. I was wondering whether I could just let JAGS skip the NAs, i.e. use a matrix whose rows have different numbers of columns.

NAs are in ex_expectancy and ex_shock.

# Data ----
# 3 subjects (rows) x 14 trials (columns). The trailing trials of subjects 2
# and 3 are NA in both ex_expectancy and ex_shock.
ex_expectancy <- structure(list(`1` = c(9L, 5L, 1L), `2` = c(5L, 6L, 1L), `3` = c(2L, 7L, 4L), `4` = c(3L, 6L, 2L), `5` = c(9L, 6L, 4L), `6` = c(9L, 7L, 1L), `7` = c(3L, 5L, 5L), `8` = c(8L, 5L, 1L), `9` = c(10L, 5L, NA), `10` = c(9L, NA, NA), `11` = c(2L, NA, NA), `12` = c(3L,NA, NA), `13` = c(3L, NA, NA), `14` = c(4L, NA, NA)), row.names = c(NA,-3L), class = c("data.table", "data.frame"))
ex_shock <- structure(list(`1` = c(0L, 1L, 1L), `2` = c(0L, 1L, 0L), `3` = c(1L, 0L, 1L), `4` = c(1L, 0L, 1L), `5` = c(0L, 1L, 1L), `6` = c(0L, 0L, 0L), `7` = c(1L, 0L, 1L), `8` = c(1L, 1L, 1L), `9` = c(0L,1L, NA), `10` = c(1L, NA, NA), `11` = c(1L, NA, NA), `12` = c(0L, NA, NA), `13` = c(1L, NA, NA), `14` = c(0L, NA, NA)), row.names = c(NA,-3L), class = c("data.table", "data.frame"))

# JAGS reads data as numeric arrays, so pass plain matrices rather than the
# data.table/data.frame objects above.
expectancy_mat <- as.matrix(ex_expectancy)
shock_mat <- as.matrix(ex_shock)

n_subjects <- nrow(expectancy_mat)
n_trials <- ncol(expectancy_mat)

# v[i, 1] is fixed at 0 as data; the other entries are defined in the model.
v <- matrix(NA, nrow = n_subjects, ncol = n_trials)
v[, 1] <- 0

dlist <- list(
  Nsubjects = n_subjects,
  Ntrials = n_trials,
  expectancy = expectancy_mat,
  shock = shock_mat,
  v = v
)

# Initial values for the single chain used below: one alpha per subject.
myinits <- list(list(
  alpha = runif(n_subjects, 0, 1)
))
parameters <- c("alpha", "v", "predk", "scale", "c", "tau")

# model
# JAGS model handed to jags() below as `model.file = RW`. The body is written
# in the JAGS language (`~`, dnorm, dunif, pow), i.e. it is model code rather
# than R code meant to be run.
# For each subject i and trial j >= 2 it defines the update
#   pe[i,j-1] = shock[i,j-1] - v[i,j-1]
#   v[i,j]    = v[i,j-1] + alpha[i] * pe[i,j-1]
# with v[i,1] supplied as data, and a normal likelihood for expectancy[i,j]
# with mean scale[i] * v[i,j] + c[i] and precision tau[i,j].
# NOTE(review): every subject loops to the same scalar Ntrials, so the trailing
# NAs in `shock` and `expectancy` for subjects 2 and 3 fall inside the loops;
# that is the problem this question is about.
RW <- function(){
  for (i in 1:Nsubjects)
  {
    for (j in 2:Ntrials) # for each trial; trial 1 has no likelihood term
    {
      # likelihood: expectancy depends linearly on the current value v[i,j]
      expectancy [i,j] ~ dnorm (scale [i] * v[i,j] + c[i],tau[i,j])
      # posterior predictive: replicate draw from the same distribution
      predk [i,j] ~ dnorm (scale [i] * v[i,j] + c[i],tau[i,j])
      
      
      # prediction error on the previous trial, then the updated value
      pe [i,j-1] <- shock [i,j-1] - v [i,j-1]
      v [i,j] <- v [i,j-1] + alpha [i]  * pe [i,j-1]
    }
  }
  # priors
  for (i in 1: Nsubjects){
    # alpha: step size of the v update; scale/c: slope/intercept of the mean
    alpha [i] ~ dunif (0,1)
    scale [i] ~ dunif (0,10)
    c[i] ~ dunif (0,5)
    
    # one noise sd per subject and trial; tau[i,1] is defined here but never
    # used, because the likelihood loop above starts at j = 2
    for (j in 1:Ntrials){
      sigma[i,j] ~ dunif (0,5)
      tau [i,j] <- 1/pow(sigma [i,j],2)
    }}
}


# Fit the model: one chain of 1000 iterations, the first 500 discarded as
# burn-in, no thinning, DIC computed.
# NOTE(review): presumably R2jags::jags(); no library(R2jags) call is visible
# in this snippet.
samples <- jags(dlist, inits = myinits, parameters.to.save = parameters,
                model.file = RW,
                n.chains = 1, n.iter = 1000, n.burnin = 500, n.thin = 1,
                DIC = TRUE)

Solution

  • The workaround here is a bit simpler than my usual nested indexing, because the missing data are always on the right side of your matrices (i.e., once a subject's value is NA, it stays NA for the rest of that row). So instead of applying nested indexing inside the loops, you only need to give the inner (trial) for loop a subject-specific upper bound. (I'm using runjags here, as that is the package I'm most familiar with.)

    # Data ----
    # 3 subjects (rows) x 14 trials (columns). Missing values are trailing:
    # once a subject's entry is NA, the rest of that row is NA as well.
    ex_expectancy <- structure(list(`1` = c(9L, 5L, 1L), `2` = c(5L, 6L, 1L), `3` = c(2L, 7L, 4L), `4` = c(3L, 6L, 2L), `5` = c(9L, 6L, 4L), `6` = c(9L, 7L, 1L), `7` = c(3L, 5L, 5L), `8` = c(8L, 5L, 1L), `9` = c(10L, 5L, NA), `10` = c(9L, NA, NA), `11` = c(2L, NA, NA), `12` = c(3L,NA, NA), `13` = c(3L, NA, NA), `14` = c(4L, NA, NA)), row.names = c(NA,-3L), class = c("data.table", "data.frame"))
    ex_shock <- structure(list(`1` = c(0L, 1L, 1L), `2` = c(0L, 1L, 0L), `3` = c(1L, 0L, 1L), `4` = c(1L, 0L, 1L), `5` = c(0L, 1L, 1L), `6` = c(0L, 0L, 0L), `7` = c(1L, 0L, 1L), `8` = c(1L, 1L, 1L), `9` = c(0L,1L, NA), `10` = c(1L, NA, NA), `11` = c(1L, NA, NA), `12` = c(0L, NA, NA), `13` = c(1L, NA, NA), `14` = c(0L, NA, NA)), row.names = c(NA,-3L), class = c("data.table", "data.frame"))

    library(runjags)

    expectancy_mat <- as.matrix(ex_expectancy)
    shock_mat <- as.matrix(ex_shock)

    n_subjects <- nrow(shock_mat)
    max_trials <- ncol(shock_mat)

    # Observed trials per subject. Counting non-NA entries only equals the
    # index of the last observed trial when the NAs are trailing, so check it.
    n_trials <- rowSums(!is.na(shock_mat))
    if (!all(is.na(shock_mat) == (col(shock_mat) > n_trials))) {
      stop("`ex_shock` must only contain trailing NAs in each row.",
           call. = FALSE)
    }

    # v[i, 1] = 0 is supplied as data; the other entries are defined in the model.
    v <- matrix(NA, nrow = n_subjects, ncol = max_trials)
    v[, 1] <- 0

    dlist <- list(
      NSubjects = n_subjects,
      Ntrials = n_trials,  # now a vector: one trial count per subject
      maxTrials = max_trials,
      expectancy = expectancy_mat,
      shock = shock_mat,
      v = v
    )

    n_chains <- 2
    # run.jags() takes one named list of initial values per chain.
    myinits <- replicate(n_chains, list(alpha = runif(n_subjects, 0, 1)),
                         simplify = FALSE)
    parameters <- c("alpha", "v", "predk", "scale", "c", "tau")

    # Model ----
    # Ntrials[i] bounds the trial loop per subject, so the trailing NAs are
    # never touched. tau[i, 1] and tau[i, j] for j > Ntrials[i] do not enter
    # the likelihood, so their draws come from the prior alone.
    model_code <- "
    model{
        for (i in 1:NSubjects){
          for (j in 2:Ntrials[i]){
            expectancy[i,j] ~ dnorm (scale[i] * v[i,j] + c[i],tau[i,j])
            # posteiror predictive
            predk[i,j] ~ dnorm (scale [i] * v[i,j] + c[i],tau[i,j])
            pe[i,j-1] <- shock[i,j-1] - v[i,j-1]
            v[i,j] <- v[i,j-1] + alpha[i]  * pe[i,j-1]
          }
        }
        # priors
        for (i in 1: NSubjects){
          alpha[i] ~ dunif (0,1)
          scale[i] ~ dunif (0,10)
          c[i] ~ dunif (0,5)
          
          for (j in 1:maxTrials){
            sigma[i,j] ~ dunif (0,5)
            tau[i,j] <- 1/pow(sigma [i,j],2)
          }}
      }"
    # writeLines() writes the file directly. With sink(), a failure before the
    # matching sink() call would keep diverting console output.
    writeLines(model_code, "model.txt")

    samples <- run.jags("model.txt", monitor = parameters, data = dlist,
                        inits = myinits, n.chains = n_chains,
                        sample = 10000, burnin = 5000, thin = 1)
    

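    Basically, Ntrials becomes a vector of length NSubjects that holds the number of observed trials for each subject. With that small change the model compiles and runs. It does not, however, address any potential fitting issues with the model. As I'm unsure what you are actually fitting, I cannot tell whether the model is correctly specified. Looking at the MCMC output, something odd still seems to be going on. Some entries of predk and tau are NA, and some values are impossible for this model: tau[3,3] is negative although tau = 1/sigma^2, and v[1,1] is not 0 although v[,1] is fixed at 0 in the data. This suggests the printed values may not line up with their names.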

    # Pool the chains' draws into a single matrix and show the first row
    # (one value per monitored node), rounded to two decimals.
    library(coda)
    my_mcmc <- as.mcmc.list(samples)
    my_mcmc <- as.matrix(my_mcmc)
    print(round(my_mcmc[1, ], digits = 2))
      alpha[1]    alpha[2]    alpha[3]      v[1,1]      v[2,1]      v[3,1] 
           0.23        0.13        0.75        0.48        0.32        0.12 
         v[1,2]      v[2,2]      v[3,2]      v[1,3]      v[2,3]      v[3,3] 
           0.23        0.09        0.05        0.08        0.29        7.60 
         v[1,4]      v[2,4]      v[3,4]      v[1,5]      v[2,5]      v[3,5] 
           0.05        0.11        0.10        0.15        0.62        0.00 
         v[1,6]      v[2,6]      v[3,6]      v[1,7]      v[2,7]      v[3,7] 
           0.00        0.00        0.00        0.13        0.75        0.00 
         v[1,8]      v[2,8]      v[3,8]      v[1,9]      v[2,9]     v[1,10] 
           0.24        0.19        0.23        0.21        0.80        0.41 
        v[1,11]     v[1,12]     v[1,13]     v[1,14]  predk[1,2]  predk[2,2] 
           0.18        0.95        0.32        0.29          NA          NA 
     predk[3,2]  predk[1,3]  predk[2,3]  predk[3,3]  predk[1,4]  predk[2,4] 
             NA        5.30       -0.48        1.83        2.01        7.30 
     predk[3,4]  predk[1,5]  predk[2,5]  predk[3,5]  predk[1,6]  predk[2,6] 
           0.57        2.77        6.49        2.37        2.82        5.76 
     predk[3,6]  predk[1,7]  predk[2,7]  predk[3,7]  predk[1,8]  predk[2,8] 
           6.23       -4.78        7.10        5.12        3.10        0.95 
     predk[3,8]  predk[1,9]  predk[2,9] predk[1,10] predk[1,11] predk[1,12] 
          -0.34       -0.31       10.04        4.13        2.60        9.53 
    predk[1,13] predk[1,14]    scale[1]    scale[2]    scale[3]        c[1] 
             NA       10.83          NA          NA        1.48        2.46 
           c[2]        c[3]    tau[1,1]    tau[2,1]    tau[3,1]    tau[1,2] 
           4.75        1.36          NA          NA       10.40          NA 
       tau[2,2]    tau[3,2]    tau[1,3]    tau[2,3]    tau[3,3]    tau[1,4] 
             NA        2.61          NA          NA       -2.68          NA 
       tau[2,4]    tau[3,4]    tau[1,5]    tau[2,5]    tau[3,5]    tau[1,6] 
             NA        1.35        8.25        1.19        0.14        0.11 
       tau[2,6]    tau[3,6]    tau[1,7]    tau[2,7]    tau[3,7]    tau[1,8] 
           0.08        0.10        0.09        0.22        0.76        4.85 
       tau[2,8]    tau[3,8]    tau[1,9]    tau[2,9]    tau[3,9]   tau[1,10] 
           0.70       16.36        1.66        1.59        0.05        3.98 
      tau[2,10]   tau[3,10]   tau[1,11]   tau[2,11]   tau[3,11]   tau[1,12] 
           0.07        0.05       53.40        0.08        9.30        0.08 
      tau[2,12]   tau[3,12]   tau[1,13]   tau[2,13]   tau[3,13]   tau[1,14] 
           0.18        0.10        0.12        0.87        0.08        0.09 
      tau[2,14]   tau[3,14] 
           5.40        0.04 ```