Tags: r, bayesian, stan, rstan

Plotting interaction effects in Bayesian models (using rstanarm)


I'm trying to show how the effect of one variable changes with the value of another variable in a Bayesian linear model fitted with rstanarm. I am able to fit the model and take draws from the posterior to look at the estimate of each parameter, but it's not clear to me how to plot the effect of one variable in the interaction as the other changes, together with the associated uncertainty (i.e. a marginal effects plot). Below is my attempt:

library(rstanarm)

# Set Seed
set.seed(1)

# Generate fake data
w1 <- rbeta(n = 50, shape1 = 2, shape2 = 1.5)
w2 <- rbeta(n = 50, shape1 = 3, shape2 = 2.5)

dat <- data.frame(y = log(w1 / (1-w1)),
                  x = log(w2 / (1-w2)),
                  z = 1:50)

# Fit linear regression without an intercept:
m1 <- rstanarm::stan_glm(y ~ 0 + x*z, 
                         data = dat,
                         family = gaussian(),
                         algorithm = "sampling",
                         chains = 4,
                         seed = 123)


# Create data sets with low values and high values of one of the predictors
dat_lowx <- dat
dat_lowx$x <- 0

dat_highx <- dat
dat_highx$x <- 5

out_low <- rstanarm::posterior_predict(object = m1,
                                   newdata = dat_lowx)

out_high <- rstanarm::posterior_predict(object = m1,
                                        newdata = dat_highx)

# Calculate differences in posterior predictions
mfx <- out_high - out_low

# Somehow get the coefficients for the other predictor?
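Because the model is linear with an identity link, the high-minus-low difference targeted by the attempt above can also be computed directly from the coefficient draws. This avoids the extra residual noise that posterior_predict() adds to every prediction. The sketch below uses simulated stand-in draws so that it runs without Stan: the data frame `post_sim` and its values are hypothetical. With the real fit you would use as.data.frame(m1), whose coefficient columns for y ~ 0 + x*z are named "x", "z" and "x:z".

```r
# Hypothetical stand-in for as.data.frame(m1): 4000 draws of the three
# coefficients of y ~ 0 + x*z (values are made up for illustration).
set.seed(42)
n_draws <- 4000
post_sim <- data.frame(
  x     = rnorm(n_draws, mean = 0.30, sd = 0.05),
  z     = rnorm(n_draws, mean = 0.01, sd = 0.002),
  "x:z" = rnorm(n_draws, mean = -0.004, sd = 0.001),
  check.names = FALSE
)

z_vals <- 1:50   # the values of z in dat
x_low  <- 0
x_high <- 5

# Expected value mu = b_x*x + b_z*z + b_xz*x*z for every draw (rows)
# and every value of z (columns), evaluated at a fixed value of x.
mu_at <- function(x_val) {
  outer(post_sim$x, rep(x_val, length(z_vals))) +
    outer(post_sim$z, z_vals) +
    outer(post_sim[["x:z"]], x_val * z_vals)
}

# High-minus-low contrast: a draws x 50 matrix, one column per value of z.
mfx_mu <- mu_at(x_high) - mu_at(x_low)

# In this linear model the contrast equals (x_high - x_low) * (b_x + b_xz * z):
# the b_z * z term cancels, so z enters only through the interaction coefficient.
mfx_check <- (x_high - x_low) * (post_sim$x + outer(post_sim[["x:z"]], z_vals))
all.equal(mfx_mu, mfx_check)
```

With a real fit, swapping posterior_predict() for posterior_linpred() in the attempt above should give the same noise-free contrast, since the identity link makes the linear predictor equal to mu.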

Solution

  • In this (linear, Gaussian, identity link, no intercept) case,

    mu = beta_x * x + beta_z * z + beta_xz * x * z 
       = (beta_x + beta_xz * z) * x 
       = (beta_z + beta_xz * x) * z
    

    So, to plot the marginal effect of x you need a range of values of z (and vice versa for the effect of z), together with the posterior draws of the coefficients, which you can obtain via

    post <- as.data.frame(m1)
    

    Then

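    # Rows are posterior draws; columns are the sorted values of the other variable
    dmu_dx <- post[ , "x"] + post[ , "x:z"] %*% t(sort(dat$z))  # effect of x at each z
    dmu_dz <- post[ , "z"] + post[ , "x:z"] %*% t(sort(dat$x))  # effect of z at each x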
    

    Each column of dmu_dx is the posterior distribution of the effect of x on mu at one observed value of z, and each column of dmu_dz is the effect of z on mu at one observed value of x. You can then label the columns and plot one interval per observation, for example with bayesplot:

    colnames(dmu_dx) <- sort(dat$z)
    colnames(dmu_dz) <- round(sort(dat$x), digits = 2)
    bayesplot::mcmc_intervals(dmu_dz)
    bayesplot::mcmc_intervals(dmu_dx)
    

    Note that the column names are simply the observed values of the other predictor, so each interval corresponds to one observation in the data.
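Instead of one interval per observed value, the same draws can be summarised over an evenly spaced grid of the moderating variable and drawn as a line with a credible band, which is the usual look of a marginal-effects plot. The sketch below again uses hypothetical simulated draws in place of as.data.frame(m1). The grid, the 90% interval, and the helper name `summarise_effect` are illustrative choices, and base graphics stand in for bayesplot.

```r
# Hypothetical stand-in for as.data.frame(m1) (made-up values); with a real
# fit use post <- as.data.frame(m1) instead.
set.seed(7)
n_draws <- 4000
post <- data.frame(
  x     = rnorm(n_draws, mean = 0.30, sd = 0.05),
  z     = rnorm(n_draws, mean = 0.01, sd = 0.002),
  "x:z" = rnorm(n_draws, mean = -0.004, sd = 0.001),
  check.names = FALSE
)

# Posterior of the marginal effect b_main + b_int * m at each grid value m:
# returns the median and an equal-tailed interval with probability `prob`.
summarise_effect <- function(b_main, b_int, grid, prob = 0.9) {
  draws <- b_main + outer(b_int, grid)            # draws x length(grid)
  probs <- c((1 - prob) / 2, 0.5, (1 + prob) / 2)
  q <- apply(draws, 2, quantile, probs = probs)   # 3 x length(grid)
  data.frame(moderator = grid, lower = q[1, ], median = q[2, ], upper = q[3, ])
}

z_grid <- seq(1, 50, length.out = 100)   # spans the observed range of dat$z
eff_x  <- summarise_effect(post$x, post[["x:z"]], z_grid)

# Marginal effect of x across z: median line over a shaded 90% band.
plot(eff_x$moderator, eff_x$median, type = "n",
     ylim = range(eff_x$lower, eff_x$upper),
     xlab = "z", ylab = "Marginal effect of x on mu")
polygon(c(eff_x$moderator, rev(eff_x$moderator)),
        c(eff_x$lower, rev(eff_x$upper)), col = "grey85", border = NA)
lines(eff_x$moderator, eff_x$median, lwd = 2)
abline(h = 0, lty = 2)
```

For the effect of z, call summarise_effect(post$z, post[["x:z"]], x_grid) with a grid such as seq(min(dat$x), max(dat$x), length.out = 100). A fine grid shows the effect as a continuous function of the moderator rather than only at the observed values.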