This vignette illustrates how to use bayesEP to fit a
hierarchical Bayesian model using expectation propagation (EP). EP is a
message-passing algorithm that approximates the posterior distribution
of shared (hierarchical) parameters by iteratively combining information
from independently fitted subsets of the data. This makes it possible to
fit large hierarchical models without needing to run MCMC on the full
dataset.
For a complete description of EP for approximating Bayesian models, see Vehtari et al. (2020).
The basic workflow is:

1. Define a model with group-specific and shared (hierarchical) parameters.
2. Write a fitting function that can fit the model to a subset of the
data, optionally using a cavity distribution as the prior on the shared
parameters.
3. Call fit_ep(), which handles distributing the data to each site,
running the core EP loop, computing the cavity distributions, and
monitoring convergence.
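We will illustrate with a simple hierarchical normal model. Each observation belongs to a group \(c[i] \in \{ 1, \dots, C \}\). Group-specific means are drawn from a shared normal distribution, and each observation is normally distributed around the mean of its group: \[ \beta_c \sim N\left(\mu_{\beta}, \sigma_{\beta}^2\right), \qquad y_i \sim N\left( \beta_{c[i]}, \sigma^2 \right). \] The shared (hierarchical) parameters are collected in the vector \(\phi = \left(\mu_{\beta}, \log \sigma_{\beta}, \log \sigma\right)\).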
simulate_data <- function(C, N, beta_mu, beta_sigma, sigma, seed = 10016) {
  # Simulate grouped data from the two-level normal model.
  #
  # C, N:       number of groups and of observations.
  # beta_mu:    mean of the group means.
  # beta_sigma: sd of the group means.
  # sigma:      observation-level sd.
  # seed:       passed to set.seed() (note: this resets the global RNG).
  #
  # Returns list(data, true_phi, true_beta), where `data` has columns
  # c (group label in 1..C) and y (outcome), and true_phi is
  # c(beta_mu, log(beta_sigma), log(sigma)).
  set.seed(seed)

  # RNG draws, in this order: group means, group labels, outcomes.
  group_means <- rnorm(C, beta_mu, beta_sigma)
  obs <- data.frame(c = sample(1:C, size = N, replace = TRUE))
  obs$y <- rnorm(N, mean = group_means[obs$c], sigma)

  list(
    data = obs,
    true_phi = c(beta_mu, log(beta_sigma), log(sigma)),
    true_beta = group_means
  )
}
We simulate 300 observations across 10 groups:
# Raw observations by group, jittered horizontally to reduce overplotting.
ggplot(data = sim$data, mapping = aes(x = factor(c), y = y)) +
  geom_jitter(alpha = 0.5, width = 0.15) +
  labs(x = "Group", y = expression(y))
The Stan model is written to support both the full joint model (with
original priors) and the EP tilted distribution fit (with a cavity
distribution as the prior on \(\phi\)).
The use_cavity flag switches between the two options. When
use_cavity = 1, the original priors on the shared
parameters are replaced by the cavity distribution: a multivariate normal with mean cavity_mu and covariance cavity_Sigma.
data {
  // Number of groups
  int<lower=0> C;
  // Number of observations
  int<lower=0> N;
  // Outcomes
  vector[N] y;
  // Group of each observation, in 1..C
  array[N] int<lower=1, upper=C> group_index;
  // Flag: 0 = use original priors, 1 = use cavity distribution
  int<lower=0, upper=1> use_cavity;
  // Cavity distribution over phi = (beta_mu, log_beta_sigma, log_sigma).
  // Must be supplied even when use_cavity = 0 (it is then unused).
  vector[3] cavity_mu;
  // Declared as cov_matrix so that a non-symmetric or non-positive-definite
  // cavity covariance is rejected with a clear message when the data are
  // read, rather than failing inside multi_normal at initialization.
  cov_matrix[3] cavity_Sigma;
}
parameters {
  // Local parameters: standardized group effects (non-centered
  // parameterization; beta is built from these in transformed parameters)
  vector[C] raw_beta;
  // Shared parameters. Both scales are sampled on the log scale and
  // exponentiated in transformed parameters.
  real beta_mu;
  real log_beta_sigma;
  real log_sigma;
}
transformed parameters {
  // Scales back on the positive (natural) scale
  real beta_sigma = exp(log_beta_sigma);
  real sigma = exp(log_sigma);
  // Group means from the non-centered parameterization
  vector[C] beta = beta_sigma * raw_beta + beta_mu;
  // Shared parameters stacked as phi = (beta_mu, log_beta_sigma, log_sigma);
  // this is the vector EP approximates and the cavity prior applies to.
  vector[3] phi = [beta_mu, log_beta_sigma, log_sigma]';
}
model {
  // Non-centered prior on the group effects:
  // beta = beta_mu + beta_sigma * raw_beta with raw_beta ~ N(0, 1)
  raw_beta ~ std_normal();
  // Prior on the shared parameters. phi is an identity re-packaging of
  // (beta_mu, log_beta_sigma, log_sigma), so no Jacobian adjustment is needed.
  if (use_cavity == 1) {
    // EP tilted fit: cavity distribution replaces the original priors
    phi ~ multi_normal(cavity_mu, cavity_Sigma);
  } else {
    // Full joint model: original priors
    beta_mu ~ std_normal();
    log_beta_sigma ~ std_normal();
    log_sigma ~ std_normal();
  }
  // Likelihood, vectorized: beta[group_index] is the vector of N group means
  y ~ normal(beta[group_index], sigma);
}
bayesEP requires the user to supply a
fit_model function that knows how to fit the model to a
data frame (which may be a subsetted version of the full data). This
function must accept the following arguments:

- data: a data frame containing the observations for one or more groups.
- use_cavity: logical flag indicating whether to use the cavity distribution as the prior for the shared parameters.
- cavity_mu: mean vector of the cavity distribution (length \(d\)).
- cavity_Sigma: covariance matrix of the cavity distribution (\(d \times d\)).

On success, the function must return a list containing at minimum a
phi element: a matrix of posterior draws for the shared
parameters (rows = draws, columns = parameters). On failure (e.g., MCMC
sampling fails), it should return NULL. Any additional
elements, such as the full Stan fit object, are passed through and
stored in the final EP results.
fit_model <- function(data, use_cavity = FALSE, cavity_mu = NULL, cavity_Sigma = NULL) {
  # Fit the hierarchical normal model (the compiled CmdStan model `model`)
  # to `data`, a data frame with a group column `c` and an outcome `y`.
  #
  # use_cavity:   if TRUE, the Stan program uses N(cavity_mu, cavity_Sigma)
  #               as the prior on phi instead of the original priors.
  # cavity_mu:    length-3 mean vector; defaults to rep(0, 3).
  # cavity_Sigma: 3 x 3 covariance matrix; defaults to diag(3).
  #
  # Returns list(phi, samples, group_index):
  #   phi         - draws_matrix of phi (rows = draws, columns = phi[1..3]),
  #   samples     - the CmdStan fit,
  #   group_index - data frame mapping each original group label `c` to the
  #                 Stan index `ci`.
  # Returns NULL (with a warning) if sampling or draw extraction fails, as
  # the fit_ep() contract requires.

  # The Stan data block declares the cavity unconditionally, so placeholder
  # values are needed even when it is unused. Each default is applied
  # independently so a caller-supplied value is never overwritten.
  if (is.null(cavity_mu)) {
    cavity_mu <- rep(0, 3)
  }
  if (is.null(cavity_Sigma)) {
    cavity_Sigma <- diag(3)
  }

  # Stan needs group indices 1..C, but a site may hold any subset of the
  # original group labels: index each label by its rank among those present.
  group_index <- data.frame(c = sort(unique(data$c)))
  group_index$ci <- seq_len(nrow(group_index))
  data$ci <- match(data$c, group_index$c)

  stan_data <- list(
    N = nrow(data),
    y = data$y,
    group_index = data$ci,
    C = nrow(group_index),
    use_cavity = as.integer(use_cavity),
    cavity_mu = cavity_mu,
    cavity_Sigma = cavity_Sigma
  )

  # Cavity (tilted) fits use fewer iterations than the full reference fit.
  n_iter <- if (use_cavity) 1000 else 2000

  tryCatch(
    {
      samples <- model$sample(
        data = stan_data,
        chains = 4,
        parallel_chains = 4,
        iter_warmup = n_iter,
        iter_sampling = n_iter,
        adapt_delta = 0.99,
        max_treedepth = 12,
        refresh = 0,
        show_messages = FALSE,
        show_exceptions = FALSE,
        seed = 10016
      )
      list(
        phi = posterior::as_draws_matrix(samples$draws("phi")),
        samples = samples,
        group_index = group_index
      )
    },
    error = function(e) {
      # e.g. no chain finished, so draws() errors; report it as NULL rather than aborting EP
      warning("fit_model(): sampling failed: ", conditionMessage(e), call. = FALSE)
      NULL
    }
  )
}
Before running EP, we fit the full joint model as a reference. This
uses all the data and the original priors for the shared parameters
(use_cavity = FALSE):
Check diagnostics:
# Run CmdStan's diagnose utility on the full-data reference fit; its report
# (treedepth, divergences, E-BFMI, ESS and R-hat checks) is shown below.
full_result$samples$cmdstan_diagnose()
#> Checking sampler transitions treedepth.
#> Treedepth satisfactory for all transitions.
#>
#> Checking sampler transitions for divergences.
#> No divergent transitions found.
#>
#> Checking E-BFMI - sampler transitions HMC potential energy.
#> E-BFMI satisfactory.
#>
#> Rank-normalized split effective sample size satisfactory for all parameters.
#>
#> Rank-normalized split R-hat values satisfactory for all parameters.
#>
#> Processing complete, no problems detected.

Posterior marginal densities for the shared parameters:
# Colors for each plotted series, keyed by the legend label used in the plots.
color_scheme <- setNames(
  c("steelblue", "firebrick", "darkgray"),
  c("Full MCMC", "EP", "Truth")
)
# Full-MCMC draws of phi as a data frame, one column per phi element.
phi_df <- as.data.frame(full_result$phi)
make_phi_plot <- function(phi_df, col_name, true_val, xlab) {
  # Density of the full-MCMC draws in column `col_name` of `phi_df`, with a
  # dashed vertical line at the true value and `xlab` as the x-axis label.
  density_layer <- geom_density(color = color_scheme["Full MCMC"])
  truth_layer <- geom_vline(xintercept = true_val, linetype = "dashed")
  ggplot(phi_df, aes(x = .data[[col_name]])) +
    density_layer +
    truth_layer +
    labs(x = xlab)
}
# One density panel per shared parameter, side by side.
gridExtra::grid.arrange(
  make_phi_plot(phi_df, "phi[1]", sim$true_phi[1], expression(mu[beta])),
  make_phi_plot(phi_df, "phi[2]", sim$true_phi[2], expression(log(sigma[beta]))),
  # log(sigma), not log[sigma]: in plotmath `[` renders a subscript
  make_phi_plot(phi_df, "phi[3]", sim$true_phi[3], expression(log(sigma))),
  nrow = 1
)
Now we fit the same model using EP via fit_ep(). The key
arguments are:

- data: the full dataset.
- group_column: name of the column identifying the groups.
- K: number of EP sites. Each site handles approximately C/K groups.
- d: dimension of the shared parameter vector \(\phi\).
- fit_model: the fitting function defined above.
- max_iter: maximum number of EP iterations.
- conv_tol: convergence tolerance on the natural parameter updates.

We also set verbose = TRUE to print progress, and save_all_tilted_fits = TRUE to keep the site fits in ep_result$tilted_fits, which we use below to extract the group-specific parameters.

We start with K = 2 sites, so that each site handles 5 of the 10 groups:
ep_result <- fit_ep(
  # Model and data
  fit_model = fit_model,
  data = sim$data,
  group_column = "c",
  # Problem size: 3 shared parameters, split across 2 sites
  d = 3,
  K = 2,
  # Stopping rules
  max_iter = 25,
  conv_tol = 10,
  # Output options
  save_all_tilted_fits = TRUE,
  verbose = TRUE
)
#> --- Starting iteration 1 (damping = 0.5000) ---
#> > Site: 1 2
#> --- Finished iteration 1 (max_delta_Q = 605.1, max_delta_r = 22.2) ---
#> --- Starting iteration 2 (damping = 0.2669) ---
#> > Site: 1 2
#> --- Finished iteration 2 (max_delta_Q = 285.4, max_delta_r = 11.5) ---
#> --- Starting iteration 3 (damping = 0.2149) ---
#> > Site: 1 2
#> --- Finished iteration 3 (max_delta_Q = 201.2, max_delta_r = 10.3) ---
#> --- Starting iteration 4 (damping = 0.2033) ---
#> > Site: 1 2
#> --- Finished iteration 4 (max_delta_Q = 135.8, max_delta_r = 6.1) ---
#> --- Starting iteration 5 (damping = 0.2007) ---
#> > Site: 1 2
#> --- Finished iteration 5 (max_delta_Q = 109.3, max_delta_r = 3.5) ---
#> --- Starting iteration 6 (damping = 0.2002) ---
#> > Site: 1 2
#> --- Finished iteration 6 (max_delta_Q = 134.0, max_delta_r = 7.2) ---
#> --- Starting iteration 7 (damping = 0.2000) ---
#> > Site: 1 2
#> --- Finished iteration 7 (max_delta_Q = 133.0, max_delta_r = 3.9) ---
#> --- Starting iteration 8 (damping = 0.2000) ---
#> > Site: 1 2
#> --- Finished iteration 8 (max_delta_Q = 116.8, max_delta_r = 4.2) ---
#> --- Starting iteration 9 (damping = 0.2000) ---
#> > Site: 1 2
#> --- Finished iteration 9 (max_delta_Q = 34.8, max_delta_r = 2.1) ---
#> --- Starting iteration 10 (damping = 0.2000) ---
#> > Site: 1 2
#> --- Finished iteration 10 (max_delta_Q = 73.1, max_delta_r = 1.6) ---
#> --- Starting iteration 11 (damping = 0.2000) ---
#> > Site: 1 2
#> --- Finished iteration 11 (max_delta_Q = 30.2, max_delta_r = 2.7) ---
#> --- Starting iteration 12 (damping = 0.2000) ---
#> > Site: 1 2
#> --- Finished iteration 12 (max_delta_Q = 8.0, max_delta_r = 1.4) ---
#> === Converged (conv_tol = 10.0) ===

The EP approximation to the posterior of the shared parameters is a
multivariate normal with mean ep_result$mu and covariance
ep_result$Sigma:
shared_param_names <- c("mu_beta", "log_beta_sigma", "log_sigma")
f <- scales::number_format(accuracy = 0.01)
# Side-by-side summary of true values, full-MCMC and EP moments of phi.
local({
  full_draws <- as.matrix(full_result$phi)
  data.frame(
    parameter = shared_param_names,
    true = f(sim$true_phi),
    full_mean = f(unname(colMeans(full_draws))),
    ep_mean = f(ep_result$mu),
    full_sd = f(unname(apply(full_draws, 2, sd))),
    ep_sd = f(unname(sqrt(diag(ep_result$Sigma))))
  )
}) |>
  knitr::kable()
| parameter | true | full_mean | ep_mean | full_sd | ep_sd |
|---|---|---|---|---|---|
| mu_beta | 0.00 | -0.04 | -0.06 | 0.43 | 0.41 |
| log_beta_sigma | 0.00 | 0.33 | 0.33 | 0.24 | 0.24 |
| log_sigma | 0.00 | 0.04 | 0.04 | 0.04 | 0.04 |
We can overlay the EP Gaussian approximation on the full MCMC posterior densities to visually assess how well EP captures the marginal posteriors:
make_comparison_plot <- function(phi_df, col_name, ep_mu, ep_sd, true_val, xlab) {
  # Overlay the EP Gaussian marginal N(ep_mu, ep_sd^2) on the full-MCMC
  # density of column `col_name`, with the true value as a faint dashed line.
  ep_density <- function(x) dnorm(x, ep_mu, ep_sd)
  base_plot <- ggplot(phi_df, aes(x = .data[[col_name]]))
  base_plot +
    geom_density(aes(color = "Full MCMC")) +
    geom_function(aes(color = "EP"), fun = ep_density) +
    geom_vline(xintercept = true_val, linetype = "dashed", alpha = 0.5) +
    scale_color_manual(values = color_scheme) +
    labs(x = xlab, y = NULL, color = NULL) +
    theme(legend.position = "bottom")
}
# Namespace-qualified so this chunk works without library(gridExtra).
gridExtra::grid.arrange(
  make_comparison_plot(phi_df, "phi[1]", ep_result$mu[1], sqrt(ep_result$Sigma[1, 1]),
                       sim$true_phi[1], expression(mu[beta])),
  make_comparison_plot(phi_df, "phi[2]", ep_result$mu[2], sqrt(ep_result$Sigma[2, 2]),
                       sim$true_phi[2], expression(log(sigma[beta]))),
  make_comparison_plot(phi_df, "phi[3]", ep_result$mu[3], sqrt(ep_result$Sigma[3, 3]),
                       sim$true_phi[3], expression(log(sigma))),
  nrow = 1
)
We can also extract the tilted fits from the final iteration for each site, which contain posterior draws for the group-specific parameters \(\beta_c\).
summarize_beta <- function(x) {
  # Summarize the beta draws of one site's tilted fit `x`, labelling each
  # row with the original group id from the fit's group_index (beta[k]
  # corresponds to the k-th sorted group label of that site).
  site_fit <- x$fit
  beta_summary <- posterior::summarize_draws(site_fit$samples$draws("beta"))
  beta_summary$variable <- NULL
  beta_summary$group <- site_fit$group_index$c
  beta_summary
}
# Tilted fits from the last EP iteration: one per site, each covering that
# site's subset of groups.
final_iteration <- length(ep_result$tilted_fits)
ep_beta <- do.call(rbind, lapply(ep_result$tilted_fits[[final_iteration]], summarize_beta))
# Full-data reference. Label rows with the group ids recorded by fit_model()
# instead of assuming beta[k] is group k, so labels stay correct even if
# some group ids are absent from the data.
full_beta <- full_result$samples$summary("beta")
full_beta$group <- full_result$group_index$c
true_beta <- data.frame(
  group = seq_along(sim$true_beta),
  beta = sim$true_beta
)

# Posterior medians and 90% intervals (q5-q95) per group: EP nudged up,
# full MCMC nudged down, truth on the group line.
# geom_errorbar(orientation = "y") replaces geom_errorbarh(), which is
# deprecated in recent ggplot2 and in older releases takes `height`
# rather than `width` (so `width = 0` was not honoured there).
ggplot(ep_beta, aes(x = median, y = factor(group))) +
  geom_point(
    aes(x = beta, color = "Truth"),
    data = true_beta
  ) +
  geom_point(
    aes(color = "EP"),
    position = position_nudge(y = 0.1)
  ) +
  geom_errorbar(
    aes(xmin = q5, xmax = q95, color = "EP"),
    orientation = "y",
    width = 0,
    position = position_nudge(y = 0.1)
  ) +
  geom_point(
    aes(color = "Full MCMC"),
    data = full_beta,
    position = position_nudge(y = -0.1)
  ) +
  geom_errorbar(
    aes(xmin = q5, xmax = q95, color = "Full MCMC"),
    data = full_beta,
    orientation = "y",
    width = 0,
    position = position_nudge(y = -0.1)
  ) +
  scale_color_manual(values = color_scheme) +
  labs(x = expression(beta[c]), y = "Group", color = NULL)