Getting started with bayesEP

library(bayesEP)
library(cmdstanr)
library(posterior)
library(ggplot2)
library(gridExtra)

This vignette illustrates how to use bayesEP to fit a hierarchical Bayesian model using expectation propagation (EP). EP is a message-passing algorithm that approximates the posterior distribution of shared (hierarchical) parameters by iteratively combining information from independently fitted subsets of the data. This makes it possible to fit large hierarchical models without needing to run MCMC on the full dataset.
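The general EP scheme can be summarized as follows. This is the standard formulation, not a description of bayesEP's internals. Split the data into K disjoint subsets \(y_1, \dots, y_K\) (the sites). For the shared parameters, the posterior then factorizes as
\[ p(\phi \mid y) \propto p(\phi) \prod_{k=1}^{K} p(y_k \mid \phi), \qquad p(y_k \mid \phi) = \int p(y_k \mid \beta_k, \phi)\, p(\beta_k \mid \phi)\, d\beta_k, \]
where \(\beta_k\) collects the group-specific parameters in site k.

EP approximates each site factor by a Gaussian \(g_k(\phi)\), so that \(g(\phi) \propto p(\phi) \prod_k g_k(\phi)\) approximates the posterior. A site update proceeds in three steps:

- Form the cavity and tilted distributions
\[ g_{-k}(\phi) \propto \frac{g(\phi)}{g_k(\phi)}, \qquad g_{\backslash k}(\phi) \propto g_{-k}(\phi)\, p(y_k \mid \phi). \]
- Approximate the tilted distribution (in this vignette, by MCMC).
- Update \(g_k\) so that \(g_{-k}(\phi)\, g_k(\phi)\) matches the mean and covariance of the tilted distribution.

Details such as damping and update order inside fit_ep() may differ from this outline.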

For a complete description of EP for approximating Bayesian models, see Vehtari et al. (2020).

The basic workflow is:

1. Define a model with group-specific and shared (hierarchical) parameters.
2. Write a fitting function that fits the model to a subset of the data, optionally using a cavity distribution as the prior on the shared parameters.
3. Call fit_ep(), which distributes the data across sites, runs the core EP loop, computes the cavity distributions, and monitors convergence.
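Step 3 involves computing cavity distributions. For Gaussian approximations, this is easiest in natural parameters: precision \(Q = \Sigma^{-1}\) and shift \(r = \Sigma^{-1}\mu\). The global approximation is a product of the prior and the site approximations, so its natural parameters are sums. Removing site k to form the cavity is therefore a subtraction. The following is a minimal base-R sketch assuming this bookkeeping. The helper name compute_cavity and the toy site values are hypothetical and are not bayesEP's internal API.

```r
# Hypothetical helper (not part of bayesEP): cavity distribution for one site.
# Q_global, r_global: natural parameters of the global approximation g(phi).
# Q_site, r_site:     natural parameters of the site approximation g_k(phi).
compute_cavity <- function(Q_global, r_global, Q_site, r_site) {
  # Dividing Gaussian densities subtracts their natural parameters
  Q_cav <- Q_global - Q_site
  r_cav <- r_global - r_site
  # The cavity must be a proper Gaussian; chol() errors if Q_cav is not
  # positive definite
  R <- chol(Q_cav)
  Sigma_cav <- chol2inv(R)            # inverse of Q_cav via its Cholesky factor
  mu_cav <- drop(Sigma_cav %*% r_cav) # mean = Sigma %*% r
  list(mu = mu_cav, Sigma = Sigma_cav)
}

# Toy example with d = 3: prior N(0, I) and two (made-up) site approximations
Q_prior <- diag(3)
r_prior <- rep(0, 3)
Q_1 <- diag(c(2, 4, 8)); r_1 <- c(0.2, 0.4, 0.8)
Q_2 <- diag(c(3, 5, 9)); r_2 <- c(-0.3, 0.5, 0.9)

# Global approximation: natural parameters add up
Q_global <- Q_prior + Q_1 + Q_2
r_global <- r_prior + r_1 + r_2

# Cavity for site 1 = prior combined with site 2 only
cav_1 <- compute_cavity(Q_global, r_global, Q_1, r_1)
cav_1$mu     # same as solve(Q_prior + Q_2, r_prior + r_2)
cav_1$Sigma  # same as solve(Q_prior + Q_2)
```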

Simulating data

We will illustrate with a simple hierarchical normal model. Each observation belongs to a group \(c[i] \in \{ 1, \dots, C \}\). Group-specific means are drawn from a shared normal distribution: \[ \beta_c \sim N\left(\mu_{\beta}, \sigma_{\beta}^2\right), \qquad y_i \sim N\left( \beta_{c[i]}, \sigma^2 \right). \]

simulate_data <- function(C, N, beta_mu, beta_sigma, sigma, seed = 10016) {
  set.seed(seed)

  # Group-level hierarchical means
  true_beta <- rnorm(C, beta_mu, beta_sigma)
  
  # Observations
  data   <- data.frame(c = sample(1:C, size = N, replace = TRUE))
  data$y <- rnorm(N, mean = true_beta[data$c], sigma)
  
  # True values of shared parameters
  true_phi <- c(beta_mu, log(beta_sigma), log(sigma))

  list(
    data = data,
    true_phi = true_phi,
    true_beta = true_beta
  )
}

We simulate 300 observations across 10 groups:

sim <- simulate_data(
  C = 10, 
  N = 300,
  beta_mu = 0,
  beta_sigma = 1.0,
  sigma = 1.0
)
ggplot(sim$data, aes(x = factor(c), y = y)) +
  geom_jitter(width = 0.15, alpha = 0.5) +
  labs(x = "Group", y = expression(y))
Simulated observations by group.

Defining the model

Shared vs. group-specific parameters

The key distinction in the EP framework is between shared parameters and group-specific parameters. In our model:

- Shared parameters \(\phi = (\mu_\beta, \log \sigma_\beta, \log \sigma)\) appear in every group's likelihood. These are the parameters that EP approximates by passing messages between sites. The dimension of the shared parameters in this example is \(d = 3\).
- Group-specific parameters \(\beta_c\) each appear in only one group and are not included in the EP messages.
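Each \(\beta_c\) enters the likelihood of a single group only. The data can therefore be split into sites by whole groups, so that every group-specific parameter lives in exactly one site and only \(\phi\) has to be shared. This vignette does not describe the exact assignment rule that fit_ep() uses. The sketch below shows one simple round-robin scheme in base R; the function name assign_sites is hypothetical.

```r
# Hypothetical illustration (not necessarily bayesEP's rule): assign whole
# groups to K sites round-robin, so every group lives in exactly one site.
assign_sites <- function(data, group_column, K) {
  groups <- sort(unique(data[[group_column]]))
  site_of_group <- rep_len(seq_len(K), length(groups))  # 1, 2, ..., K, 1, 2, ...
  site <- site_of_group[match(data[[group_column]], groups)]
  split(data, site)  # list of K data frames, one per site
}

# Toy data shaped like the simulation below: 300 observations in 10 groups
set.seed(1)
toy <- data.frame(c = sample(1:10, 300, replace = TRUE), y = rnorm(300))
sites <- assign_sites(toy, "c", K = 2)

# Groups handled by each site
lapply(sites, function(s) sort(unique(s$c)))
```

With K = 2, each site receives half of the groups, and no group is split across sites.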

Shared parameters must be defined on an unconstrained scale (that is, taking values in \(\mathbb{R}\)) so that the multivariate normal approximation used by EP is reasonable. Constrained parameters therefore need to be transformed. For example, parameters constrained to be positive (such as standard deviations or variances) are log-transformed and sampled on the unconstrained log scale, as is done here for \(\sigma_\beta\) and \(\sigma\).

To complete the Bayesian model specification, we place normal priors on each shared parameter: \[ \begin{align} \mu_{\beta} &\sim N(0, 1), \\ \log(\sigma_{\beta}) &\sim N(0, 1), \\ \log(\sigma) &\sim N(0, 1). \end{align} \]
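These priors are placed on the log scale, so the implied prior on \(\sigma_\beta\) (and likewise on \(\sigma\)) is log-normal with median 1. A quick base-R check of the change of variables and of the implied central 95% prior interval:

```r
# Change of variables: if log(s) ~ N(0, 1), the density of s is
# dnorm(log(s)) / s, which is the log-normal density
s <- c(0.5, 1, 2)
all.equal(dnorm(log(s)) / s, dlnorm(s, meanlog = 0, sdlog = 1))
#> [1] TRUE

# Central 95% prior interval and median on the natural scale
round(qlnorm(c(0.025, 0.5, 0.975), meanlog = 0, sdlog = 1), 2)
#> [1] 0.14 1.00 7.10
```

The prior therefore allows standard deviations from roughly 0.14 to 7.1, which comfortably covers the value of 1 used in the simulation.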

Stan model

The Stan model below supports both the full joint model (with the original priors) and the EP tilted-distribution fit (with the cavity distribution as the prior on \(\phi\)). The use_cavity flag switches between the two. When use_cavity = 1, the original priors on the shared parameters are replaced by a multivariate normal cavity distribution with mean cavity_mu and covariance cavity_Sigma. The prior on the group-specific parameters is unchanged.

data {
  // Number of groups
  int<lower=0> C;
  
  // Number of observations
  int<lower=0> N;
  
  vector[N] y;
  array[N] int<lower=1, upper=C> group_index;

  // Flag: 0 = use original priors, 1 = use cavity distribution
  int<lower=0, upper=1> use_cavity;

  // Cavity distribution parameters
  vector[3] cavity_mu;
  matrix[3, 3] cavity_Sigma;
}
parameters {
  // Local parameters
  vector[C] raw_beta;
  
  // Shared parameters
  real beta_mu;
  real log_beta_sigma;
  real log_sigma;
}
transformed parameters {
  // Apply transformations
  real beta_sigma = exp(log_beta_sigma);
  real sigma = exp(log_sigma);

  // Transform from non-centered parameterization
  vector[C] beta = beta_mu + beta_sigma * raw_beta;

  // Collect shared parameters into a single vector
  vector[3] phi;
  phi[1] = beta_mu;
  phi[2] = log_beta_sigma;
  phi[3] = log_sigma;
}
model {
  // Local group-specific parameters
  raw_beta ~ std_normal();

  // Shared parameters
  if(use_cavity == 1) {
    phi ~ multi_normal(cavity_mu, cavity_Sigma);
  }
  else {
    beta_mu ~ std_normal();
    log_beta_sigma ~ std_normal();
    log_sigma ~ std_normal();
  }

  // Likelihood (vectorized over observations)
  y ~ normal(beta[group_index], sigma);
}
Save the Stan program above as model.stan and compile it with cmdstanr:

model <- cmdstanr::cmdstan_model("model.stan")

Model fitting function

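bayesEP requires the user to supply a fit_model function that fits the model to a data frame (which may be a subset of the full data). This function must accept the following arguments:

- data: a data frame with the observations for one site (or the full data set).
- use_cavity: logical; if TRUE, the cavity distribution replaces the original priors on the shared parameters.
- cavity_mu: mean vector of the cavity distribution (length d).
- cavity_Sigma: covariance matrix of the cavity distribution (d × d).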

On success, the function must return a list containing at minimum a phi element: a matrix of posterior draws for the shared parameters (rows = draws, columns = parameters). On failure (e.g., MCMC sampling fails), it should return NULL. Any additional elements, such as the full Stan fit object, are passed through and stored in the final EP results.
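Before passing a custom fitting function to fit_ep(), it can help to check that its return value follows this contract. The following is a minimal sketch; check_fit_result is a hypothetical helper, not part of bayesEP.

```r
# Hypothetical helper (not part of bayesEP): check a fit_model() result
# against the contract above (NULL, or a list whose phi has d columns).
check_fit_result <- function(result, d) {
  if (is.null(result)) return(TRUE)  # NULL signals a failed fit
  if (!is.list(result) || is.null(result$phi)) {
    stop("result must be NULL or a list with a 'phi' element")
  }
  phi <- as.matrix(result$phi)       # rows = draws, columns = parameters
  if (ncol(phi) != d) {
    stop(sprintf("phi has %d columns, expected d = %d", ncol(phi), d))
  }
  if (nrow(phi) < 2) {
    stop("phi needs at least two draws to estimate a covariance")
  }
  TRUE
}

# Fake result with 100 draws of 3 shared parameters plus an extra element
fake <- list(phi = matrix(rnorm(300), ncol = 3), extra = "passed through")
check_fit_result(fake, d = 3)
check_fit_result(NULL, d = 3)
```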

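fit_model <- function(data, use_cavity = FALSE, cavity_mu = NULL, cavity_Sigma = NULL) {
  # Stan requires the cavity inputs even when use_cavity = FALSE, so supply
  # placeholder values (the model ignores them in that case)
  if (is.null(cavity_mu)) {
    cavity_mu <- rep(0, 3)
    cavity_Sigma <- diag(3)
  }

  # Re-index the groups present in this subset as 1, ..., C for Stan, keeping
  # the original labels so that beta[ci] can be mapped back to group c later
  group_index <- data.frame(c = sort(unique(data$c)))
  group_index$ci <- seq_len(nrow(group_index))
  
  data$ci <- group_index$ci[base::match(data$c, group_index$c)]
  
  stan_data <- list(
    N = nrow(data),
    y = data$y,
    group_index = data$ci,
    C = nrow(group_index),
    use_cavity = as.integer(use_cavity),
    cavity_mu = cavity_mu,
    cavity_Sigma = cavity_Sigma
  )

  # Return NULL if sampling or draw extraction fails, as required above
  tryCatch({
    samples <- model$sample(
      data = stan_data,
      chains = 4,
      parallel_chains = 4,
      # Tilted fits (use_cavity = TRUE) use fewer iterations than the full fit
      iter_warmup = if (use_cavity) 1000 else 2000,
      iter_sampling = if (use_cavity) 1000 else 2000,
      adapt_delta = 0.99,
      max_treedepth = 12,
      refresh = 0,
      show_messages = FALSE,
      show_exceptions = FALSE,
      seed = 10016
    )

    list(
      phi = posterior::as_draws_matrix(samples$draws("phi")),
      samples = samples,
      group_index = group_index
    )
  }, error = function(e) NULL)
}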
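fit_model() re-indexes groups because the Stan data block requires group_index values in 1, ..., C, whereas a site receives an arbitrary subset of the original group labels. Here is the same match()-based step applied to a toy subset that contains only groups 3, 7, and 9:

```r
# Toy site containing only groups 3, 7 and 9 (labels from the full data)
site_data <- data.frame(c = c(7, 3, 9, 3, 7), y = c(0.1, -0.4, 1.2, 0.3, 0.8))

# Sorted original labels get consecutive Stan indices 1, 2, 3
group_index <- data.frame(c = sort(unique(site_data$c)))
group_index$ci <- seq_len(nrow(group_index))
site_data$ci <- group_index$ci[match(site_data$c, group_index$c)]

site_data$ci
#> [1] 2 1 3 1 2
```

The stored group_index table (original label c alongside Stan index ci) is what later allows the per-site beta summaries to be mapped back to the original groups.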

Fitting the full joint model

Before running EP, we fit the full joint model as a reference. This uses all the data and the original priors for the shared parameters (use_cavity = FALSE):

full_result <- fit_model(sim$data, use_cavity = FALSE)

Check diagnostics:

full_result$samples$cmdstan_diagnose()
#> Checking sampler transitions treedepth.
#> Treedepth satisfactory for all transitions.
#> 
#> Checking sampler transitions for divergences.
#> No divergent transitions found.
#> 
#> Checking E-BFMI - sampler transitions HMC potential energy.
#> E-BFMI satisfactory.
#> 
#> Rank-normalized split effective sample size satisfactory for all parameters.
#> 
#> Rank-normalized split R-hat values satisfactory for all parameters.
#> 
#> Processing complete, no problems detected.

Posterior marginal densities for the shared parameters:

color_scheme <- c("Full MCMC" = "steelblue", "EP" = "firebrick", "Truth" = "darkgray")
phi_df <- as.data.frame(full_result$phi)

make_phi_plot <- function(phi_df, col_name, true_val, xlab) {
  ggplot(phi_df, aes(x = .data[[col_name]])) + 
    geom_density(color = color_scheme["Full MCMC"]) + 
    geom_vline(xintercept = true_val, linetype = "dashed") +
    labs(x = xlab)
}

gridExtra::grid.arrange(
  make_phi_plot(phi_df, "phi[1]", sim$true_phi[1], expression(mu[beta])),
  make_phi_plot(phi_df, "phi[2]", sim$true_phi[2], expression(log(sigma[beta]))),
  make_phi_plot(phi_df, "phi[3]", sim$true_phi[3], expression(log(sigma))),
  nrow = 1 
)
Full joint model posterior densities (blue). Vertical dashed lines show true parameter values.

Fitting with expectation propagation

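Now we fit the same model using EP via fit_ep(). The key arguments are:

- data: the full data set.
- group_column: the name of the column identifying groups; whole groups are assigned to sites.
- K: the number of sites.
- d: the dimension of the shared parameters (here 3).
- fit_model: the fitting function defined above.
- max_iter: the maximum number of EP iterations.
- conv_tol: the convergence tolerance.
- verbose: whether to print progress messages.
- save_all_tilted_fits: whether to store the tilted fits from every iteration.

We start with K = 2 sites, so that each site handles 5 of the 10 groups: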

ep_result <- fit_ep(
  data = sim$data,
  group_column = "c",
  K = 2,
  d = 3,
  fit_model = fit_model,
  max_iter = 25,
  conv_tol = 10,
  verbose = TRUE, 
  save_all_tilted_fits = TRUE
)
#> --- Starting iteration 1 (damping = 0.5000) ---
#>     > Site: 1 2 
#> --- Finished iteration 1 (max_delta_Q = 605.1, max_delta_r = 22.2) ---
#> --- Starting iteration 2 (damping = 0.2669) ---
#>     > Site: 1 2 
#> --- Finished iteration 2 (max_delta_Q = 285.4, max_delta_r = 11.5) ---
#> --- Starting iteration 3 (damping = 0.2149) ---
#>     > Site: 1 2 
#> --- Finished iteration 3 (max_delta_Q = 201.2, max_delta_r = 10.3) ---
#> --- Starting iteration 4 (damping = 0.2033) ---
#>     > Site: 1 2 
#> --- Finished iteration 4 (max_delta_Q = 135.8, max_delta_r = 6.1) ---
#> --- Starting iteration 5 (damping = 0.2007) ---
#>     > Site: 1 2 
#> --- Finished iteration 5 (max_delta_Q = 109.3, max_delta_r = 3.5) ---
#> --- Starting iteration 6 (damping = 0.2002) ---
#>     > Site: 1 2 
#> --- Finished iteration 6 (max_delta_Q = 134.0, max_delta_r = 7.2) ---
#> --- Starting iteration 7 (damping = 0.2000) ---
#>     > Site: 1 2 
#> --- Finished iteration 7 (max_delta_Q = 133.0, max_delta_r = 3.9) ---
#> --- Starting iteration 8 (damping = 0.2000) ---
#>     > Site: 1 2 
#> --- Finished iteration 8 (max_delta_Q = 116.8, max_delta_r = 4.2) ---
#> --- Starting iteration 9 (damping = 0.2000) ---
#>     > Site: 1 2 
#> --- Finished iteration 9 (max_delta_Q = 34.8, max_delta_r = 2.1) ---
#> --- Starting iteration 10 (damping = 0.2000) ---
#>     > Site: 1 2 
#> --- Finished iteration 10 (max_delta_Q = 73.1, max_delta_r = 1.6) ---
#> --- Starting iteration 11 (damping = 0.2000) ---
#>     > Site: 1 2 
#> --- Finished iteration 11 (max_delta_Q = 30.2, max_delta_r = 2.7) ---
#> --- Starting iteration 12 (damping = 0.2000) ---
#>     > Site: 1 2 
#> --- Finished iteration 12 (max_delta_Q = 8.0, max_delta_r = 1.4) ---
#> === Converged (conv_tol = 10.0) ===
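Each iteration in the log reports a damping factor and the largest changes in the site parameters (max_delta_Q and max_delta_r). The run was declared converged at iteration 12, the first iteration in which both reported changes were below conv_tol = 10. This vignette does not spell out how bayesEP defines these quantities. The sketch below assumes a common form of damped EP update in natural parameters (site precision Q and shift r). In this form, a damping factor \(\delta \in (0, 1]\) moves each site part of the way toward its undamped update. The function name damped_site_update and the max-absolute-change measure are illustrative assumptions.

```r
# Hypothetical sketch of a damped EP site update in natural parameters.
# Q_old, r_old: current site parameters; Q_new, r_new: undamped update
# derived from the tilted moments; delta: damping factor in (0, 1].
damped_site_update <- function(Q_old, r_old, Q_new, r_new, delta) {
  Q <- (1 - delta) * Q_old + delta * Q_new
  r <- (1 - delta) * r_old + delta * r_new
  list(
    Q = Q,
    r = r,
    max_delta_Q = max(abs(Q - Q_old)),  # largest change in any precision entry
    max_delta_r = max(abs(r - r_old))   # largest change in any shift entry
  )
}

upd <- damped_site_update(
  Q_old = diag(3) * 100, r_old = c(1, 2, 3),
  Q_new = diag(3) * 140, r_new = c(2, 2, 5),
  delta = 0.2
)
upd$max_delta_Q  # about 8, i.e. 0.2 * (140 - 100)
upd$max_delta_r  # about 0.4, i.e. 0.2 * (5 - 3)

# Assumed convergence rule: both changes below a tolerance of 10
converged <- upd$max_delta_Q < 10 && upd$max_delta_r < 10
```

Smaller damping factors make each iteration more conservative, which can stabilize EP when tilted moments are noisy MCMC estimates, at the cost of more iterations.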

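The EP approximation to the posterior of the shared parameters is a multivariate normal with mean ep_result$mu and covariance ep_result$Sigma. The table below compares its means and standard deviations with those of the full MCMC posterior and with the true values: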

shared_param_names <- c("mu_beta", "log_beta_sigma", "log_sigma")
f <- scales::number_format(accuracy = 0.01)

data.frame(
  parameter = shared_param_names,
  true = f(sim$true_phi),
  full_mean = f(unname(colMeans(as.matrix(full_result$phi)))),
  ep_mean = f(ep_result$mu),
  full_sd = f(unname(apply(as.matrix(full_result$phi), 2, sd))),
  ep_sd = f(unname(sqrt(diag(ep_result$Sigma))))
) |>
  knitr::kable()
parameter true full_mean ep_mean full_sd ep_sd
mu_beta 0.00 -0.04 -0.06 0.43 0.41
log_beta_sigma 0.00 0.33 0.33 0.24 0.24
log_sigma 0.00 0.04 0.04 0.04 0.04
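Because the EP approximation is an explicit Gaussian, approximate posterior draws of \(\phi\) can be generated directly, for example to report \(\sigma_\beta\) and \(\sigma\) on their natural scale. The following is a base-R sketch using a Cholesky factor. The mean and covariance below are illustrative stand-ins for ep_result$mu and ep_result$Sigma (the off-diagonal entry is made up), not output from the fit above.

```r
# Draw from N(mu, Sigma): with U = chol(Sigma) (so Sigma = t(U) %*% U) and
# z ~ N(0, I), each row of Z %*% U has covariance Sigma
rmvnorm_chol <- function(n, mu, Sigma) {
  U <- chol(Sigma)
  Z <- matrix(rnorm(n * length(mu)), nrow = n)
  sweep(Z %*% U, 2, mu, "+")  # add the mean to every row; rows are draws
}

# Illustrative stand-ins for ep_result$mu and ep_result$Sigma
mu <- c(-0.06, 0.33, 0.04)
Sigma <- diag(c(0.41, 0.24, 0.04)^2)
Sigma[1, 2] <- Sigma[2, 1] <- 0.01

set.seed(2024)
phi_draws <- rmvnorm_chol(4000, mu, Sigma)
colnames(phi_draws) <- c("mu_beta", "log_beta_sigma", "log_sigma")

# Natural-scale draws of the two standard deviations
sigma_draws <- exp(phi_draws[, c("log_beta_sigma", "log_sigma")])
apply(sigma_draws, 2, quantile, probs = c(0.05, 0.5, 0.95))

# Correlations implied by the EP covariance
cov2cor(Sigma)
```

With the fitted object, replace mu and Sigma by ep_result$mu and ep_result$Sigma.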

Comparing full MCMC and EP

We can overlay the EP Gaussian approximation on the full MCMC posterior densities to visually assess how well EP captures the marginal posteriors:

make_comparison_plot <- function(phi_df, col_name, ep_mu, ep_sd, true_val, xlab) {
  ggplot(phi_df, aes(x = .data[[col_name]])) +
    geom_density(aes(color = "Full MCMC")) +
    geom_function(aes(color = "EP"), fun = function(x) dnorm(x, ep_mu, ep_sd)) +
    geom_vline(xintercept = true_val, linetype = "dashed", alpha = 0.5) +
    scale_color_manual(values = color_scheme) +
    labs(x = xlab, y = NULL, color = NULL) +
    theme(legend.position = "bottom")
}

grid.arrange(
  make_comparison_plot(phi_df, "phi[1]", ep_result$mu[1], sqrt(ep_result$Sigma[1, 1]), 
                       sim$true_phi[1], expression(mu[beta])),
  make_comparison_plot(phi_df, "phi[2]", ep_result$mu[2], sqrt(ep_result$Sigma[2, 2]),
                       sim$true_phi[2], expression(log(sigma[beta]))),
  make_comparison_plot(phi_df, "phi[3]", ep_result$mu[3], sqrt(ep_result$Sigma[3, 3]),
                       sim$true_phi[3], expression(log(sigma))),
  nrow = 1
)
Comparison of full MCMC posterior densities (blue) and EP Gaussian approximations (red)
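The visual comparison can be complemented by a single number. One option is the Kullback-Leibler divergence from a Gaussian moment-matched to the full MCMC draws to the EP Gaussian; this vignette does not show bayesEP computing it. The divergence is zero only when the two Gaussians coincide. A base-R sketch:

```r
# KL( N(m0, S0) || N(m1, S1) ) for d-dimensional Gaussians:
# 0.5 * [tr(S1^-1 S0) + (m1 - m0)' S1^-1 (m1 - m0) - d + log|S1| - log|S0|]
kl_gaussian <- function(m0, S0, m1, S1) {
  d <- length(m0)
  S1_inv <- solve(S1)
  delta_m <- m1 - m0
  logdet <- function(S) as.numeric(determinant(S, logarithm = TRUE)$modulus)
  0.5 * (sum(diag(S1_inv %*% S0)) +
         drop(t(delta_m) %*% S1_inv %*% delta_m) -
         d + logdet(S1) - logdet(S0))
}

kl_gaussian(c(0, 0), diag(2), c(0, 0), diag(2))  # identical Gaussians: 0
kl_gaussian(c(0, 0), diag(2), c(1, 0), diag(2))  # unit mean shift: 0.5

# Usage with this vignette's objects (not run here):
# phi_mat <- as.matrix(full_result$phi)
# kl_gaussian(colMeans(phi_mat), cov(phi_mat), ep_result$mu, ep_result$Sigma)
```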

We can also extract the tilted fits from the final iteration for each site, which contain posterior draws for the group-specific parameters \(\beta_c\).


summarize_beta <- function(x) {
  s <- posterior::summarize_draws(x$fit$samples$draws("beta"))
  s$variable <- NULL
  s$group <- x$fit$group_index$c
  s
}
final_iteration <- length(ep_result$tilted_fits)
ep_beta <- do.call(rbind, lapply(ep_result$tilted_fits[[final_iteration]], summarize_beta))

full_beta <- full_result$samples$summary("beta")
full_beta$group <- 1:nrow(full_beta)

true_beta <- data.frame(
  group = 1:length(sim$true_beta), 
  beta = sim$true_beta
)
ggplot(ep_beta, aes(x = median, y = factor(group))) +
  geom_point(
    aes(x = beta, color = "Truth"), 
    data = true_beta
  ) +
  geom_point(
    aes(color = "EP"), 
    position = position_nudge(y = 0.1)
  ) +
  geom_linerange(
    aes(xmin = q5, xmax = q95, color = "EP"), 
    position = position_nudge(y = 0.1)
  ) +
  geom_point(
    aes(color = "Full MCMC"), 
    data = full_beta, 
    position = position_nudge(y = -0.1)
  ) +
  geom_linerange(
    aes(xmin = q5, xmax = q95, color = "Full MCMC"), 
    data = full_beta, 
    position = position_nudge(y = -0.1)
  ) +
  scale_color_manual(values = color_scheme) +
  labs(x = expression(beta[c]), y = "Group", color = NULL)
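Once \(\phi\) is held fixed, the model is conjugate in \(\beta_c\), which gives an additional sanity check on these per-group summaries. With \(n_c\) observations in group c,
\[ \beta_c \mid y, \phi \sim N\left(v_c\left(\frac{\mu_\beta}{\sigma_\beta^2} + \frac{\sum_{i: c[i]=c} y_i}{\sigma^2}\right),\ v_c\right), \qquad v_c = \left(\frac{1}{\sigma_\beta^2} + \frac{n_c}{\sigma^2}\right)^{-1}. \]
Plugging in a point estimate of \(\phi\) ignores the uncertainty in \(\phi\), so the resulting intervals should be somewhat narrower than those in the plot. The following is a base-R sketch; the function name conditional_beta is hypothetical.

```r
# Hypothetical helper: conditional posterior of each beta_c given fixed
# shared parameters phi = (mu_beta, log sigma_beta, log sigma)
conditional_beta <- function(y, group, phi) {
  mu_beta <- phi[1]
  tau2 <- exp(2 * phi[2])                     # sigma_beta^2
  s2 <- exp(2 * phi[3])                       # sigma^2
  groups <- sort(unique(group))
  n_c <- as.vector(tapply(y, group, length))  # observations per group (sorted)
  sum_c <- as.vector(tapply(y, group, sum))   # sum of y per group (sorted)
  v <- 1 / (1 / tau2 + n_c / s2)              # conditional posterior variance
  m <- v * (mu_beta / tau2 + sum_c / s2)      # conditional posterior mean
  data.frame(group = groups, mean = m, sd = sqrt(v))
}

# Toy check with phi = 0: group 1 has y = (1, 3), group 2 has y = -2
conditional_beta(y = c(1, 3, -2), group = c(1, 1, 2), phi = c(0, 0, 0))
```

With this vignette's objects, conditional_beta(sim$data$y, sim$data$c, ep_result$mu) gives plug-in values that can be compared with the EP and full-MCMC intervals.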

References