Extracting Mean Parameter Estimates from Stan Output Table


I understand how to extract chains from a Stan model but I was wondering if there was any quick way to extract the values displayed on the default Stan output table.

Here is some toy data

# simulate linear model
a <- 3 # intercept
b <- 2 # slope

# we can have both the predictor and the noise vary
x <- rnorm(28, 0, 1)
eps <- rnorm(28, 0, 2)
y <- a + b*x + eps

When we fit a linear model to these data

mod <- lm(y ~ x)

We can extract coefficients from

mod$coefficients

# (Intercept)           x 
#    3.355967    2.151597 

I wondered if there is any way to do the equivalent with a Stan output table

# Step 1: Make List
data_reg <- list(N = 28, x = x, y = y)

# Step 2: Create Model String
write("
      data {
      int<lower=0> N;
      vector[N] x;
      vector[N] y;
      }
      parameters {
      real alpha;
      real beta;
      real<lower=0> sigma;
      }
      model {
      vector[N] mu;
      sigma ~ cauchy(0, 2);
      beta ~ normal(0,10);
      alpha ~ normal(0,100);
      for ( i in 1:N ) {
      mu[i] = alpha + beta * x[i];
      }
      y ~ normal(mu, sigma);
      }
      ", file = "temp.stan")


# Step 3: Generate MCMC Chains
library(rstan)
fit1 <- stan(file = "temp.stan",    
             data = data_reg,        
             chains = 2,             
             warmup = 1000,          
             iter = 2000,            
             cores = 2,               
             refresh = 1000) 

Now, when we call the model

fit1

# Output
#         mean se_mean   sd   2.5%    25%    50%    75%  97.5% n_eff Rhat
# alpha   3.33    0.01 0.40   2.57   3.06   3.33   3.59   4.13  1229    1
# beta    2.14    0.01 0.40   1.37   1.89   2.14   2.40   2.98  1470    1
# sigma   1.92    0.01 0.27   1.45   1.71   1.90   2.09   2.51  1211    1
# lp__  -31.92    0.05 1.30 -35.27 -32.50 -31.63 -30.96 -30.43   769    1

Is there any way to index and extract elements from the Output table displayed above?


Solution

  • If you only want means, then the get_posterior_mean function will work. Otherwise, if you assign the result of summary(fit1) to an object, you can extract values from that object, but it is probably better to just do as.matrix(fit1) or as.data.frame(fit1) and calculate whatever you want yourself on the resulting columns.
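
As a rough sketch of the three routes mentioned above, assuming fit1 is the stanfit object from the question and that rstan is loaded, something like the following should work. The object names post_means, tab and draws are only illustrative, and the parameter names alpha, beta and sigma come from the model in the question.

```r
library(rstan)

# Assumption: `fit1` is the stanfit object fitted in the question.

# 1. Means only. get_posterior_mean() returns a matrix with one row per
#    parameter and one column per chain, plus a column pooling all chains.
#    With the two chains used in the question, the pooled column is the last one.
post_means <- get_posterior_mean(fit1)
post_means[, ncol(post_means)]

# 2. The printed table. summary() returns a list whose $summary element is
#    the matrix shown when fit1 is printed, so it can be indexed by name.
tab <- summary(fit1)$summary
tab[, "mean"]                                        # all posterior means
tab["beta", "mean"]                                  # a single cell
tab[c("alpha", "beta"), c("mean", "2.5%", "97.5%")]  # a sub-table

# 3. The draws themselves. as.matrix() gives one row per post-warmup draw
#    (chains merged) and one column per parameter.
draws <- as.matrix(fit1)
colMeans(draws)                                      # means, like mod$coefficients
apply(draws, 2, quantile, probs = c(0.025, 0.5, 0.975))
```

The third route is the most flexible, since any summary (medians, intervals, functions of several parameters) can be computed from the columns of draws rather than being limited to what the printed table happens to contain.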