If the code in this vignette has not been evaluated, a rendered version is available on the documentation site under ‘Articles’.

library(ggplot2)
library(dplyr)
library(sdmTMB)
library(rstan) # for plot() method
options(mc.cores = parallel::detectCores()) # use rstan parallel processing

Bayesian estimation is possible with sdmTMB by passing fitted models to tmbstan::tmbstan() (Monnahan & Kristensen 2018). All sampling is then done using Stan (Stan Development Team 2021), and output is returned as a stanfit object.

Why might you want to pass an sdmTMB model to Stan?

  • to obtain probabilistic inference on parameters
  • to avoid the Laplace approximation on the random effects
  • to robustly quantify uncertainty on derived quantities not already calculated in the model
  • in some cases, models that struggle to converge with maximum likelihood can be adequately sampled with MCMC given carefully chosen priors (e.g., Monnahan et al. 2021)

Simulating data

Here we will demonstrate using a simulated dataset.

set.seed(123)
predictor_dat <- data.frame(
  X = runif(500), Y = runif(500),
  a1 = rnorm(500)
)
mesh <- make_mesh(predictor_dat, xy_cols = c("X", "Y"), cutoff = 0.1)
# plot(mesh)
# mesh$mesh$n
sim_dat <- sdmTMB_simulate(
  formula = ~a1,
  data = predictor_dat,
  mesh = mesh,
  family = gaussian(),
  range = 0.3,
  phi = 0.2,
  sigma_O = 0.2,
  seed = 123,
  B = c(0.8, -0.4) # B0 = intercept, B1 = a1 slope
)

Visualize our simulated data:

ggplot(sim_dat, aes(X, Y, colour = observed)) +
  geom_point() +
  scale_color_viridis_c()

Fitting the model with marginal likelihood

First, fit a spatial random field GLMM with maximum likelihood:

fit <- sdmTMB(
  observed ~ a1,
  data = sim_dat,
  mesh = mesh,
  family = gaussian(),
  spatial = "on"
)
fit
#> Spatial model fit by ML ['sdmTMB']
#> Formula: observed ~ a1
#> Mesh: mesh (isotropic covariance)
#> Data: sim_dat
#> Family: gaussian(link = 'identity')
#>  
#>             coef.est coef.se
#> (Intercept)      0.8    0.06
#> a1              -0.4    0.01
#> 
#> Dispersion parameter: 0.20
#> Matérn range: 0.32
#> Spatial SD: 0.17
#> ML criterion at convergence: -65.939
#> 
#> See ?tidy.sdmTMB to extract these values as a data frame.

Adding priors

In that first model fit we did not use any priors. In that case, they are implied as uniform on the internal parameter space. However, sdmTMB provides the option of applying priors. Here we will show an example of applying a Normal(0, 5) (mean, SD) prior on the intercept and a Normal(0, 2) prior on the slope parameter. We could guess at the model matrix structure based on our formula, but we can check it by looking at the internal model matrix from the previous fit (using do_fit = FALSE would save time if we did not want to fit the model the first time).

head(fit$tmb_data$X_ij[[1]])
#>   (Intercept)          a1
#> 1           1 -0.60189285
#> 2           1 -0.99369859
#> 3           1  1.02678506
#> 4           1  0.75106130
#> 5           1 -1.50916654
#> 6           1 -0.09514745

The columns correspond, in order, to the elements of the b priors:

fit <- sdmTMB(
  observed ~ a1,
  data = sim_dat,
  mesh = mesh,
  family = gaussian(),
  spatial = "on",
  priors = sdmTMBpriors(
    # location = vector of means; scale = vector of standard deviations:
    b = normal(location = c(0, 0), scale = c(5, 2))
  )
)
fit
#> Spatial model fit by ML ['sdmTMB']
#> Formula: observed ~ a1
#> Mesh: mesh (isotropic covariance)
#> Data: sim_dat
#> Family: gaussian(link = 'identity')
#>  
#>             coef.est coef.se
#> (Intercept)      0.8    0.06
#> a1              -0.4    0.01
#> 
#> Dispersion parameter: 0.20
#> Matérn range: 0.32
#> Spatial SD: 0.17
#> ML criterion at convergence: -62.846
#> 
#> See ?tidy.sdmTMB to extract these values as a data frame.

Fixing a spatial correlation parameter to improve convergence

Sometimes some of the spatial correlation parameters can be challenging to estimate with Stan. One option is to apply penalized complexity (PC) priors with sdmTMBpriors() to the Matérn parameters. Another option, which can also be used in conjunction with the priors, is to fix one or more parameters at their maximum likelihood estimate (MLE) values. Frequently, fixing the parameter ln_kappa can help convergence (e.g., Monnahan et al. 2021). This estimated parameter is transformed into the range estimate, so it controls the rate of spatial correlation decay.
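
As a minimal sketch of the first option, PC priors could be supplied through the priors argument; the threshold values below are illustrative only and should be chosen based on knowledge of your system:

sdmTMB(
  ...,
  priors = sdmTMBpriors(
    # P(range > 0.1) = 0.95 and P(spatial SD < 1) = 0.95:
    matern_s = pc_matern(range_gt = 0.1, sigma_lt = 1)
  )
)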

Now we will rebuild the fitted object with fixed (‘mapped’) ln_kappa parameters using the update() function. We’ll use do_fit = FALSE to avoid actually fitting the updated model since it’s not necessary.

# grab the internal parameter list at estimated values:
pars <- sdmTMB::get_pars(fit)
# create a 'map' vector for TMB
# factor NA values cause TMB to fix or map the parameter at the starting value:
kappa_map <- factor(rep(NA, length(pars$ln_kappa)))

# rebuild model updating some elements:
fit_mle <- update(
  fit,
  control = sdmTMBcontrol(
    start = list(
      ln_kappa = pars$ln_kappa #<
    ),
    map = list(
      ln_kappa = kappa_map #<
    )
  ),
  do_fit = FALSE #<
)
#>  Initiating `ln_kappa` at specified starting value(s) of:
#> 2.173, 2.173
#>  Fixing or mirroring `ln_kappa`

Passing the model to tmbstan

Now we can pass the $tmb_obj element of our model to tmbstan::tmbstan(). We are only using 1000 iterations and 2 chains so this vignette builds quickly. In practice, you will likely want to use more (e.g., 2000 iterations, 4 chains).

fit_stan <- tmbstan::tmbstan(
  fit_mle$tmb_obj,
  iter = 1000, chains = 2,
  seed = 8217 # ensures repeatability
)

Sometimes you may need to adjust the sampler settings, such as:

tmbstan::tmbstan(
  ...,
  control = list(adapt_delta = 0.9, max_treedepth = 12)
)

See the Details section in ?rstan::stan.

You can also ‘thin’ samples via the thin argument if working with model predictions becomes cumbersome given a large number of required samples.
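
For example (a sketch with illustrative values; thin is passed through to rstan):

tmbstan::tmbstan(
  ...,
  iter = 2000, chains = 4, thin = 5 # retain every 5th post-warmup draw
)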

We can look at the model:

fit_stan
#> Inference for Stan model: sdmTMB.
#> 2 chains, each with iter=1000; warmup=500; thin=1; 
#> post-warmup draws per chain=500, total post-warmup draws=1000.
#> 
#>               mean se_mean   sd   2.5%    25%    50%    75%  97.5% n_eff Rhat
#> b_j[1]        0.80    0.01 0.06   0.68   0.76   0.80   0.84   0.92    82 1.02
#> b_j[2]       -0.40    0.00 0.01  -0.41  -0.40  -0.40  -0.39  -0.38  1747 1.00
#> ln_tau_O     -1.67    0.01 0.14  -1.93  -1.76  -1.67  -1.57  -1.42   169 1.01
#> ln_phi       -1.63    0.00 0.03  -1.69  -1.65  -1.63  -1.61  -1.56  1360 1.00
#> omega_s[1]   -0.09    0.01 0.09  -0.26  -0.14  -0.09  -0.03   0.09   162 1.01
#> omega_s[2]   -0.06    0.01 0.09  -0.25  -0.12  -0.06   0.01   0.11   196 1.01
#> omega_s[3]    0.01    0.01 0.10  -0.17  -0.05   0.01   0.08   0.20   196 1.01
#> omega_s[4]   -0.21    0.01 0.09  -0.40  -0.27  -0.21  -0.15  -0.02   206 1.01
#> omega_s[5]   -0.34    0.01 0.10  -0.53  -0.41  -0.33  -0.27  -0.13   247 1.00
#> omega_s[6]   -0.08    0.01 0.10  -0.27  -0.15  -0.08   0.00   0.13   268 1.01
#> omega_s[7]   -0.02    0.01 0.09  -0.19  -0.08  -0.02   0.04   0.15   172 1.01
#> omega_s[8]   -0.23    0.01 0.10  -0.42  -0.29  -0.23  -0.16  -0.05   219 1.01
#> omega_s[9]   -0.32    0.01 0.09  -0.50  -0.38  -0.32  -0.26  -0.13   166 1.02
#> omega_s[10]   0.29    0.01 0.09   0.11   0.23   0.28   0.35   0.46   147 1.01
#> omega_s[11]  -0.16    0.01 0.09  -0.34  -0.22  -0.16  -0.09   0.02   208 1.01
#> omega_s[12]   0.00    0.01 0.10  -0.19  -0.07   0.01   0.07   0.21   205 1.01
#> omega_s[13]   0.19    0.01 0.09   0.02   0.13   0.19   0.25   0.36   166 1.01
#> omega_s[14]  -0.09    0.01 0.10  -0.29  -0.16  -0.09  -0.03   0.11   253 1.01
#> omega_s[15]   0.22    0.01 0.09   0.06   0.16   0.22   0.28   0.40   152 1.01
#> omega_s[16]  -0.02    0.01 0.10  -0.21  -0.08  -0.02   0.05   0.16   198 1.01
#> omega_s[17]  -0.15    0.01 0.09  -0.34  -0.21  -0.14  -0.08   0.04   198 1.01
#> omega_s[18]  -0.28    0.01 0.11  -0.50  -0.35  -0.28  -0.21  -0.09   270 1.00
#> omega_s[19]   0.00    0.01 0.10  -0.18  -0.06   0.00   0.07   0.19   211 1.01
#> omega_s[20]   0.03    0.01 0.08  -0.14  -0.03   0.02   0.09   0.19   180 1.01
#> omega_s[21]   0.08    0.01 0.10  -0.11   0.02   0.08   0.15   0.26   191 1.01
#> omega_s[22]  -0.01    0.01 0.10  -0.21  -0.08  -0.01   0.05   0.17   212 1.00
#> omega_s[23]   0.12    0.01 0.09  -0.04   0.06   0.12   0.18   0.30   156 1.01
#> omega_s[24]   0.20    0.01 0.10   0.00   0.13   0.19   0.26   0.40   256 1.00
#> omega_s[25]   0.08    0.01 0.09  -0.09   0.01   0.07   0.14   0.26   167 1.01
#> omega_s[26]  -0.01    0.01 0.10  -0.21  -0.08  -0.01   0.06   0.20   215 1.01
#> omega_s[27]  -0.11    0.01 0.09  -0.30  -0.18  -0.10  -0.05   0.06   216 1.00
#> omega_s[28]   0.12    0.01 0.10  -0.08   0.05   0.11   0.18   0.32   229 1.00
#> omega_s[29]   0.30    0.01 0.09   0.12   0.23   0.30   0.36   0.47   212 1.01
#> omega_s[30]  -0.04    0.01 0.09  -0.22  -0.10  -0.04   0.02   0.14   193 1.01
#> omega_s[31]   0.09    0.01 0.09  -0.08   0.04   0.09   0.15   0.26   154 1.01
#> omega_s[32]   0.05    0.01 0.12  -0.18  -0.02   0.05   0.14   0.28   244 1.01
#> omega_s[33]   0.08    0.01 0.09  -0.10   0.01   0.07   0.14   0.26   219 1.01
#> omega_s[34]   0.04    0.01 0.09  -0.13  -0.03   0.04   0.10   0.20   163 1.01
#> omega_s[35]   0.07    0.01 0.10  -0.10   0.00   0.07   0.14   0.26   230 1.01
#> omega_s[36]   0.14    0.01 0.09  -0.04   0.08   0.15   0.20   0.33   197 1.01
#> omega_s[37]   0.16    0.01 0.11  -0.06   0.09   0.16   0.24   0.39   275 1.00
#> omega_s[38]   0.12    0.01 0.10  -0.09   0.05   0.11   0.19   0.31   217 1.01
#> omega_s[39]  -0.22    0.01 0.10  -0.40  -0.28  -0.22  -0.15  -0.04   186 1.01
#> omega_s[40]  -0.03    0.01 0.09  -0.21  -0.09  -0.03   0.03   0.17   197 1.01
#> omega_s[41]   0.19    0.01 0.08   0.03   0.13   0.19   0.24   0.35   143 1.01
#> omega_s[42]   0.21    0.01 0.09   0.01   0.14   0.21   0.27   0.38   173 1.01
#> omega_s[43]   0.15    0.01 0.10  -0.04   0.07   0.15   0.21   0.35   221 1.01
#> omega_s[44]   0.14    0.01 0.10  -0.07   0.08   0.14   0.20   0.32   194 1.01
#> omega_s[45]   0.10    0.01 0.10  -0.10   0.02   0.10   0.16   0.30   248 1.00
#> omega_s[46]   0.06    0.01 0.10  -0.14  -0.01   0.06   0.12   0.26   187 1.01
#> omega_s[47]   0.31    0.01 0.09   0.13   0.24   0.31   0.37   0.50   166 1.01
#> omega_s[48]  -0.24    0.01 0.10  -0.44  -0.31  -0.24  -0.18  -0.06   196 1.01
#> omega_s[49]   0.10    0.01 0.10  -0.10   0.04   0.10   0.17   0.32   255 1.01
#> omega_s[50]  -0.09    0.01 0.09  -0.26  -0.14  -0.09  -0.03   0.08   178 1.01
#> omega_s[51]   0.25    0.01 0.11   0.04   0.17   0.25   0.32   0.48   247 1.01
#> omega_s[52]  -0.21    0.01 0.11  -0.43  -0.29  -0.21  -0.14  -0.01   251 1.01
#> omega_s[53]   0.04    0.01 0.10  -0.16  -0.03   0.04   0.10   0.22   213 1.00
#> omega_s[54]   0.03    0.01 0.10  -0.16  -0.04   0.03   0.10   0.22   214 1.00
#> omega_s[55]  -0.09    0.01 0.11  -0.30  -0.16  -0.09  -0.01   0.15   311 1.01
#> omega_s[56]  -0.42    0.01 0.10  -0.62  -0.48  -0.41  -0.35  -0.24   229 1.01
#> omega_s[57]   0.01    0.01 0.11  -0.20  -0.07   0.00   0.08   0.24   274 1.00
#> omega_s[58]  -0.21    0.01 0.10  -0.40  -0.28  -0.21  -0.14  -0.01   201 1.01
#> omega_s[59]   0.04    0.01 0.24  -0.43  -0.11   0.05   0.20   0.47   992 1.00
#> omega_s[60]  -0.23    0.01 0.28  -0.78  -0.41  -0.23  -0.03   0.27  1087 1.00
#> omega_s[61]  -0.26    0.01 0.24  -0.69  -0.41  -0.27  -0.10   0.20   662 1.00
#> omega_s[62]  -0.26    0.01 0.23  -0.72  -0.42  -0.26  -0.11   0.19   693 1.00
#> omega_s[63]  -0.27    0.01 0.23  -0.72  -0.43  -0.28  -0.12   0.18   807 1.00
#> omega_s[64]   0.08    0.01 0.22  -0.37  -0.07   0.09   0.23   0.51   866 1.00
#> omega_s[65]   0.17    0.01 0.22  -0.29   0.02   0.17   0.32   0.57   693 1.00
#> omega_s[66]   0.16    0.01 0.22  -0.27   0.02   0.16   0.31   0.60   525 1.00
#> omega_s[67]  -0.03    0.01 0.22  -0.49  -0.17  -0.01   0.13   0.42  1154 1.00
#> omega_s[68]  -0.02    0.01 0.25  -0.50  -0.17  -0.02   0.15   0.47  1152 1.00
#> omega_s[69]   0.00    0.01 0.22  -0.46  -0.15   0.00   0.15   0.43   458 1.01
#> omega_s[70]   0.01    0.01 0.21  -0.42  -0.13   0.01   0.15   0.40   431 1.00
#> omega_s[71]   0.17    0.01 0.24  -0.30   0.01   0.17   0.33   0.67   884 1.00
#> omega_s[72]  -0.12    0.01 0.24  -0.62  -0.29  -0.11   0.04   0.35  1064 1.00
#> omega_s[73]  -0.14    0.01 0.22  -0.58  -0.28  -0.15   0.01   0.30   600 1.00
#> omega_s[74]  -0.14    0.01 0.22  -0.56  -0.29  -0.14   0.01   0.26   624 1.00
#> omega_s[75]  -0.38    0.01 0.20  -0.76  -0.51  -0.38  -0.25   0.00   654 1.00
#> omega_s[76]   0.10    0.01 0.21  -0.31  -0.04   0.10   0.23   0.51   848 1.00
#> omega_s[77]   0.09    0.01 0.22  -0.34  -0.04   0.09   0.23   0.52  1053 1.00
#> omega_s[78]  -0.06    0.01 0.21  -0.47  -0.19  -0.06   0.08   0.35  1693 1.00
#> omega_s[79]   0.06    0.01 0.16  -0.27  -0.05   0.06   0.17   0.37   622 1.00
#> omega_s[80]  -0.18    0.01 0.23  -0.65  -0.33  -0.18  -0.02   0.27   632 1.00
#> omega_s[81]  -0.07    0.01 0.20  -0.45  -0.20  -0.07   0.05   0.33  1165 1.00
#> omega_s[82]  -0.14    0.01 0.21  -0.55  -0.28  -0.14   0.00   0.29   810 1.00
#> omega_s[83]  -0.02    0.01 0.23  -0.47  -0.18  -0.01   0.14   0.44  1096 1.00
#> omega_s[84]   0.12    0.01 0.20  -0.26  -0.02   0.12   0.25   0.53   952 1.00
#> omega_s[85]  -0.29    0.01 0.21  -0.68  -0.44  -0.30  -0.15   0.11  1111 1.00
#> lp__        136.06    0.93 9.26 118.11 129.87 136.38 142.46 152.88   100 1.02
#> 
#> Samples were drawn using NUTS(diag_e) at Wed Jul  3 18:50:49 2024.
#> For each parameter, n_eff is a crude measure of effective sample size,
#> and Rhat is the potential scale reduction factor on split chains (at 
#> convergence, Rhat=1).

The Rhat values look reasonable (< 1.05). The n_eff (effective sample size) values mostly look reasonable (> 100) for inference about the mean, except for the intercept (b_j[1]). Furthermore, we can see correlation in the MCMC samples for b_j[1] in the plots below. We could try running more iterations and chains and/or placing priors on this and other parameters as described above (highly recommended).
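
As a quick programmatic check of these diagnostics (a minimal sketch; the thresholds are the rules of thumb mentioned above):

stan_summary <- summary(fit_stan)$summary
max(stan_summary[, "Rhat"]) # ideally < 1.05
min(stan_summary[, "n_eff"]) # ideally > 100 for crude inference about the mean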

Now we can use various functions to visualize the posterior:

plot(fit_stan)
#> 'pars' not specified. Showing first 10 parameters by default.
#> ci_level: 0.8 (80% intervals)
#> outer_level: 0.95 (95% intervals)

pars_plot <- c("b_j[1]", "b_j[2]", "ln_tau_O", "omega_s[1]")

bayesplot::mcmc_trace(fit_stan, pars = pars_plot)

bayesplot::mcmc_pairs(fit_stan, pars = pars_plot)

Posterior predictive checks

We can perform posterior predictive checks to assess whether our model can generate predictive data that are consistent with the observations. For this, we can make use of simulate.sdmTMB() while passing in our Stan model. simulate.sdmTMB() will take draws from the joint parameter posterior and add observation error. We need to ensure nsim is less than or equal to the total number of post-warmup samples.

set.seed(19292)
samps <- sdmTMBextra::extract_mcmc(fit_stan)
s <- simulate(fit_mle, mcmc_samples = samps, nsim = 50)
bayesplot::pp_check(
  sim_dat$observed,
  yrep = t(s),
  fun = bayesplot::ppc_dens_overlay
)

See ?bayesplot::pp_check. The solid line represents the density of the observed data and the light blue lines represent the density of 50 posterior predictive simulations. In this case, the simulated data seem consistent with the observed data.

Plotting predictions

We can make predictions with our Bayesian model by supplying the posterior samples to the mcmc_samples argument in predict.sdmTMB().

pred <- predict(fit_mle, mcmc_samples = samps)

The output is a matrix where each row corresponds to a row of predicted data and each column corresponds to a sample.

dim(pred)
#> [1]  500 1000

We can summarize these draws in various ways to visualize them:

sim_dat$post_mean <- apply(pred, 1, mean)
sim_dat$post_sd <- apply(pred, 1, sd)

ggplot(sim_dat, aes(X, Y, colour = post_mean)) +
  geom_point() +
  scale_color_viridis_c()


ggplot(sim_dat, aes(X, Y, colour = post_sd)) +
  geom_point() +
  scale_color_viridis_c()
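
We could also summarize the draws with quantiles, for example as equal-tailed 95% credible intervals (a minimal sketch):

sim_dat$post_lwr <- apply(pred, 1, quantile, probs = 0.025)
sim_dat$post_upr <- apply(pred, 1, quantile, probs = 0.975)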

Or predict on a grid for a given value of a1:

nd <- expand.grid(
  X = seq(0, 1, length.out = 70),
  Y = seq(0, 1, length.out = 70),
  a1 = 0
)
pred <- predict(fit_mle, newdata = nd, mcmc_samples = samps)

nd$post_mean <- apply(pred, 1, mean)
nd$post_sd <- apply(pred, 1, sd)

ggplot(nd, aes(X, Y, fill = post_mean)) +
  geom_raster() +
  scale_fill_viridis_c() +
  coord_fixed()


ggplot(nd, aes(X, Y, fill = post_sd)) +
  geom_raster() +
  scale_fill_viridis_c() +
  coord_fixed()

Extracting parameter posterior samples

We can extract posterior samples with rstan::extract():

post <- rstan::extract(fit_stan)

The result is a list where each element corresponds to a parameter or set of parameters:

names(post)
#> [1] "b_j"      "ln_tau_O" "ln_phi"   "omega_s"  "lp__"
hist(post$b_j[, 1])

As an example of calculating a derived parameter, here we will calculate the marginal spatial random field standard deviation:

ln_kappa <- get_pars(fit_mle)$ln_kappa[1] # take the first of 2 elements; the 2nd would be for a spatiotemporal field
ln_tau_O <- post$ln_tau_O
sigma_O <- 1 / sqrt(4 * pi * exp(2 * ln_tau_O + 2 * ln_kappa))
hist(sigma_O)
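
A quick numerical summary of this derived quantity (for reference, the simulation above used sigma_O = 0.2):

quantile(sigma_O, probs = c(0.025, 0.5, 0.975))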

Extracting the posterior of other predicted elements

By default, predict.sdmTMB() returns the overall prediction in link space when a tmbstan model is passed in. If we instead want some other element that would appear in the usual data frame returned by predict.sdmTMB() for a regular sdmTMB model, we can specify it through the sims_var argument.

For example, let’s extract the spatial random field values "omega_s". Other options are documented in ?predict.sdmTMB().

fit_pred <- predict(
  fit_mle,
  newdata = nd,
  mcmc_samples = samps,
  sims_var = "omega_s" #<
)

nd$spatial_rf_mean <- apply(fit_pred, 1, mean)
nd$spatial_rf_sd <- apply(fit_pred, 1, sd)

ggplot(nd, aes(X, Y, fill = spatial_rf_mean)) +
  geom_raster() +
  scale_fill_gradient2() +
  coord_fixed()


ggplot(nd, aes(X, Y, fill = spatial_rf_sd)) +
  geom_raster() +
  scale_fill_viridis_c() +
  coord_fixed()

References

Monnahan, C. & Kristensen, K. (2018). No-U-turn sampling for fast Bayesian inference in ADMB and TMB: Introducing the adnuts and tmbstan R packages. PLoS ONE, 13.
Monnahan, C.C., Thorson, J.T., Kotwicki, S., Lauffenburger, N., Ianelli, J.N. & Punt, A.E. (2021). Incorporating vertical distribution in index standardization accounts for spatiotemporal availability to acoustic and bottom trawl gear for semi-pelagic species (O.R. Godo, Ed.). ICES Journal of Marine Science, 78, 1826–1839.
Stan Development Team. (2021). RStan: The R interface to Stan. Retrieved from https://mc-stan.org/