
Bayesian Multinomial Regression using rjags package


I am trying to fit a multinomial logistic regression model using rjags. The outcome is a categorical (nominal) variable (Outcome) with 3 levels, and the explanatory variables are Age (continuous) and Group (categorical with 3 levels). I would like to obtain the posterior means and 95% quantile-based credible intervals for the Age and Group coefficients.
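For reference, a multinomial logit for this setting can be written as follows. This assumes Outcome A is the baseline category and pink is the reference Group; both choices are arbitrary. Because Group has 3 levels, it needs two dummy variables, and each non-baseline outcome gets its own set of coefficients:

$$
\Pr(\text{Outcome}_i = k) = \frac{\exp(\eta_{ik})}{\sum_{m \in \{A, B, C\}} \exp(\eta_{im})},
\qquad
\eta_{ik} = \beta_{0k} + \beta_{1k}\,\text{Age}_i + \beta_{2k}\,\text{blue}_i + \beta_{3k}\,\text{yellow}_i
$$

This holds for k ∈ {B, C}, with η_iA = 0. Here blue_i and yellow_i are 0/1 indicators of person i's Group.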

I am not very comfortable with for loops, which I suspect is why my model code isn't working properly.

My beta priors follow a Normal distribution, βj ∼ Normal(0,100) for j ∈ {0, 1, 2}.
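JAGS's `dnorm(mu, tau)` is parameterised by the precision tau = 1/variance, not by the variance. If the 100 in Normal(0, 100) denotes a variance, the corresponding JAGS prior is `dnorm(0, 0.01)`. The `dnorm(0, 0.001)` used in the code below implies a variance of 1000 (standard deviation about 31.6). The sketch below shows the conversion; the helper names are illustrative only:

```r
# JAGS's dnorm(mu, tau) takes a precision tau = 1 / variance.
# The helper names below are illustrative, not part of rjags.
precision_from_variance <- function(variance) 1 / variance
precision_from_sd <- function(sd) 1 / sd^2

precision_from_variance(100)   # 0.01  -> dnorm(0, 0.01) if 100 is a variance
precision_from_sd(100)         # 1e-04 -> dnorm(0, 1e-04) if 100 is a standard deviation
1 / 0.001                      # 1000: the variance implied by dnorm(0, 0.001)
```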

Reproducible R code

library(rjags)

set.seed(1)
data <- data.frame(Age = round(runif(119, min = 1, max = 18)),
                   Group = c(rep("pink", 20), rep("blue", 18), rep("yellow", 81)), 
                   Outcome = c(rep("A", 45), rep("B", 19), rep("C", 55)))

X <- as.matrix(data[,c("Age", "Group")]) 
J <- ncol(X)
N <- nrow(X)

## Step 1: Specify model
cat("
model {
for (i in 1:N){

    ##Sampling model
    yvec[i] ~ dmulti(p[i,1:J], 1)
    #yvec[i] ~ dcat(p[i, 1:J])  # alternative
    for (j in 1:J){
      log(q[i,j]) <- beta0 + beta1*X[i,1] + beta2*X[i,2] 
      p[i,j] <- q[i,j]/sum(q[i,1:J])  
    } 
    
    ##Priors
    beta0 ~ dnorm(0, 0.001)
    beta1 ~ dnorm(0, 0.001)
    beta2 ~ dnorm(0, 0.001)
}
}",
file="model.txt")

##Step 2: Specify data list 
dat.list <- list(yvec = data$Outcome, X=X, J=J, N=N) 

## Step 3: Compile and adapt model in JAGS 
jagsModel<-jags.model(file = "model.txt",
                      data = dat.list,
                      n.chains = 3,
                      n.adapt = 3000
)

Error message:

(The error message was posted as a screenshot, which is not available.)
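Independent of the exact error text, a quick check of the objects built above points to several likely problems:

- `as.matrix()` on a data frame with a text column returns a character matrix, so `X` cannot enter arithmetic in JAGS.
- `J <- ncol(X)` is 2 (the number of predictor columns), not 3 (the number of outcome categories).
- JAGS needs numeric data, so `Outcome` has to be passed as integer codes.

The model itself also has issues:

- `dmulti(p[i,1:J], 1)` expects a vector of counts for each observation, whereas a single category code goes with `dcat()`.
- The linear predictor does not depend on `j`, so every category would get probability 1/J.
- The priors sit inside the `for (i in 1:N)` loop, so `beta0`, `beta1` and `beta2` would be defined N times, which JAGS does not allow.

The snippet below checks the data-related points and builds a numeric design matrix with pink as the reference level. The names `y` and `Xnum` are illustrative:

```r
# Rebuild the question's data (same seed and construction as above)
set.seed(1)
data <- data.frame(Age = round(runif(119, min = 1, max = 18)),
                   Group = c(rep("pink", 20), rep("blue", 18), rep("yellow", 81)),
                   Outcome = c(rep("A", 45), rep("B", 19), rep("C", 55)))

# 1. as.matrix() on a data frame with a text column gives a character matrix
X <- as.matrix(data[, c("Age", "Group")])
typeof(X)    # "character"

# 2. ncol(X) counts predictor columns, not outcome categories
ncol(X)      # 2, while Outcome has 3 levels

# 3. JAGS needs numeric data: code the outcome as 1, 2, 3 (A, B, C) for dcat()
y <- as.integer(factor(data$Outcome))
table(y)

# A numeric design matrix instead: intercept, Age, and 0/1 dummies for
# blue and yellow, with pink as the reference level
data$Group <- relevel(factor(data$Group), ref = "pink")
Xnum <- model.matrix(~ Age + Group, data = data)
colnames(Xnum)   # "(Intercept)" "Age" "Groupblue" "Groupyellow"
```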

Sources I have been looking at for help:

http://people.bu.edu/dietze/Bayes2018/Lesson21_GLM.pdf

Dirichlet Multinomial model in JAGS with categorical X

Reference from http://www.stats.ox.ac.uk/~nicholls/MScMCMC15/jags_user_manual.pdf, page 31


I have just started learning the rjags package, so any hints, explanations, or links to relevant sources would be greatly appreciated!


Solution

  • Here is one approach. I have used the same priors you defined for the coefficients. Because Group is a factor, I use one of its levels (pink) as the reference, so its effect is absorbed by the constant of the model. Here is the code:

    library(rjags)
    #Data
    set.seed(1)
    data <- data.frame(Age = round(runif(119, min = 1, max = 18)),
                       Group = c(rep("pink", 20), rep("blue", 18), rep("yellow", 81)), 
                       Outcome = c(rep("A", 45), rep("B", 19), rep("C", 55)))
    
    # Dummy variables for Group. pink is the reference level (r1 is not used in the model),
    # so the constant absorbs the effect of that level
    r1 <- as.numeric(data$Group=='pink')
    r2 <- as.numeric(data$Group=='blue')
    r3 <- as.numeric(data$Group=='yellow')
    age <- data$Age
    # Indicator (0/1) variables for outcome levels A, B and C
    o1 <- as.numeric(data$Outcome=='A')
    o2 <- as.numeric(data$Outcome=='B')
    o3 <- as.numeric(data$Outcome=='C')
    # Number of observations (all vectors have the same length)
    N <- length(r2)
    
    ## Step 1: Specify model
    
    model.string <- "
    model{
    for (i in 1:N){ 
    
    ## outcome indicators A, B, C (all three share the coefficients b1-b4)
    o1[i] ~ dbern(pi1[i])
    o2[i] ~ dbern(pi2[i]) 
    o3[i] ~ dbern(pi3[i]) 
    
    ## predictors
    logit(pi1[i]) <- b1+b2*age[i]+b3*r2[i]+b4*r3[i]
    logit(pi2[i]) <- b1+b2*age[i]+b3*r2[i]+b4*r3[i]
    logit(pi3[i]) <- b1+b2*age[i]+b3*r2[i]+b4*r3[i]
    
    } 
    ## priors
    b1 ~ dnorm(0, 0.001)
    b2 ~ dnorm(0, 0.001)
    b3 ~ dnorm(0, 0.001)
    b4 ~ dnorm(0, 0.001)
    }
    "
    # Read the model specification from the string
    model.spec<-textConnection(model.string)
    
    ## fit model w JAGS
    jags <- jags.model(model.spec,
                       data = list('r2'=r2,'r3'=r3,
                                   'o1'=o1,'o2'=o2,'o3'=o3,
                                   'age'=age,'N'=N),
                       n.chains=3,
                       n.adapt=3000)
    
    # Burn-in: run 1000 iterations and discard them
    update(jags, n.iter=1000,progress.bar = 'none')
    # Sampling: keep 1000 draws per chain for each coefficient
    results <- coda.samples(jags,variable.names=c("b1","b2","b3","b4"),n.iter=1000,
                            progress.bar = 'none')
    # Stack the chains into a single data frame (one column per parameter)
    Res <- do.call(rbind.data.frame, results)
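In the model above, the three Bernoulli likelihoods share the coefficients b1 to b4. The three fitted probabilities are therefore identical for each observation, they are not constrained to sum to one, and there are no separate effects per outcome category. A multinomial (softmax) logit instead gives each non-reference outcome its own coefficients and uses a single categorical likelihood, `dcat()`.

The sketch below illustrates that approach. It assumes outcome A and Group pink as reference levels and keeps the same `dnorm(0, 0.001)` priors. The names `y`, `K`, `b`, `jags.data` and `multinom.string` are illustrative choices. Fitting requires JAGS, so that step only runs if rjags is available:

```r
# Multinomial (softmax) logit in JAGS: a sketch with A and pink as reference levels
set.seed(1)
data <- data.frame(Age = round(runif(119, min = 1, max = 18)),
                   Group = c(rep("pink", 20), rep("blue", 18), rep("yellow", 81)),
                   Outcome = c(rep("A", 45), rep("B", 19), rep("C", 55)))

y   <- as.integer(factor(data$Outcome))   # category codes: 1 = A, 2 = B, 3 = C
K   <- 3L                                  # number of outcome categories
N   <- length(y)
age <- data$Age
r2  <- as.numeric(data$Group == "blue")    # dummy for blue (pink is reference)
r3  <- as.numeric(data$Group == "yellow")  # dummy for yellow

multinom.string <- "
model {
  for (i in 1:N) {
    y[i] ~ dcat(p[i, 1:K])                 # one categorical outcome per person
    for (k in 1:K) {
      # linear predictor for category k (identically 0 for k = 1, i.e. A)
      eta[i, k] <- b[1, k] + b[2, k] * age[i] + b[3, k] * r2[i] + b[4, k] * r3[i]
      q[i, k] <- exp(eta[i, k])
      p[i, k] <- q[i, k] / sum(q[i, 1:K])  # softmax: probabilities sum to 1
    }
  }
  # priors live outside the observation loop
  for (m in 1:4) {
    b[m, 1] <- 0                           # reference category A
    for (k in 2:K) {
      b[m, k] ~ dnorm(0, 0.001)
    }
  }
}
"

jags.data <- list(y = y, K = K, N = N, age = age, r2 = r2, r3 = r3)

# Fit only if rjags (and hence JAGS) is available
if (requireNamespace("rjags", quietly = TRUE)) {
  fit <- rjags::jags.model(textConnection(multinom.string), data = jags.data,
                           n.chains = 3, n.adapt = 1000, quiet = TRUE)
  update(fit, n.iter = 1000, progress.bar = "none")    # burn-in
  draws <- rjags::coda.samples(fit, variable.names = "b", n.iter = 1000,
                               progress.bar = "none")
  draws.mat <- as.matrix(draws)
  # drop the fixed zeros b[m,1] of the reference category
  est <- draws.mat[, !grepl(",1]", colnames(draws.mat), fixed = TRUE)]
  print(round(cbind(mean = colMeans(est),
                    t(apply(est, 2, quantile, probs = c(0.025, 0.975)))), 3))
}
```

In `b[m, k]`, m indexes the constant, Age, blue and yellow terms, and k indexes the outcome (2 = B, 3 = C). Each estimated coefficient is therefore a log-odds effect relative to outcome A. The column `b[m, 1]` is fixed at 0 for the reference outcome.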
    

    With the draws for each parameter saved in Res, you can compute posterior means and quantile-based credible intervals with the following code:

    #Posterior means
    apply(Res,2,mean)
    
             b1          b2          b3          b4 
    -0.79447801  0.00168827  0.07240954  0.08650250
    
    # Lower limit of the 90% interval (5% quantile)
    apply(Res, 2, quantile, probs = 0.05)
    
             b1          b2          b3          b4 
    -1.45918662 -0.03960765 -0.61027923 -0.42674155
    
    # Upper limit of the 90% interval (95% quantile)
    apply(Res, 2, quantile, probs = 0.95)
    
             b1          b2          b3          b4 
    -0.13005617  0.04013478  0.72852243  0.61216838 
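The 0.05 and 0.95 quantiles above bound a 90% interval. The 95% quantile-based intervals asked for in the question need the 0.025 and 0.975 quantiles instead. Below is a small helper (the name `summarise_draws` is illustrative) that returns the mean and both limits for each column. It is demonstrated on simulated standard-normal draws only to show the output layout, not as real results:

```r
# Posterior mean and equal-tailed (quantile-based) interval for each column
summarise_draws <- function(draws, level = 0.95) {
  alpha <- (1 - level) / 2
  t(apply(as.matrix(draws), 2, function(x) {
    c(mean = mean(x), quantile(x, probs = c(alpha, 1 - alpha)))
  }))
}

# Illustration on simulated standard-normal draws (NOT posterior output)
set.seed(42)
fake <- matrix(rnorm(2000), ncol = 2, dimnames = list(NULL, c("b1", "b2")))
s <- summarise_draws(fake)
s
# With the answer's draws this would be: summarise_draws(Res)
```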
    

    Here b1 is the constant (which absorbs the pink reference level), b2 is the Age coefficient, and b3 and b4 are the effects of blue and yellow relative to pink. The exact values may differ slightly between runs because the MCMC draws are random.
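To make the MCMC draws reproducible independently of R's own seed, each chain can be given explicit `.RNG.name` and `.RNG.seed` values through the `inits` argument of `jags.model()`. The names below are JAGS's base-module generators. The sketch uses a hypothetical helper `make_rng_inits()`; `coda::gelman.diag()` can then be used to check that the chains agree:

```r
# One inits list per chain, each with its own JAGS RNG and seed.
# make_rng_inits() is an illustrative helper, not part of rjags.
make_rng_inits <- function(n.chains, seed = 1) {
  rngs <- c("base::Wichmann-Hill", "base::Marsaglia-Multicarry",
            "base::Super-Duper", "base::Mersenne-Twister")
  lapply(seq_len(n.chains), function(ch) {
    list(.RNG.name = rngs[(ch - 1) %% length(rngs) + 1],
         .RNG.seed = seed + ch)
  })
}

inits <- make_rng_inits(3)
str(inits)

# With the answer's objects this would be passed as, for example:
# jags <- jags.model(textConnection(model.string),
#                    data = list(r2 = r2, r3 = r3, o1 = o1, o2 = o2, o3 = o3,
#                                age = age, N = N),
#                    inits = inits, n.chains = 3, n.adapt = 3000)
# and after coda.samples(), coda::gelman.diag(results) compares the chains.
```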