If the code in this vignette has not been evaluated, a rendered version is available on the documentation site under ‘Articles’.
library(ggplot2)
library(dplyr)
library(sdmTMB)
library(rstan) # for plot() method
options(mc.cores = parallel::detectCores()) # use rstan parallel processing
Bayesian estimation is possible with sdmTMB by passing fitted models to tmbstan::tmbstan() (Monnahan & Kristensen 2018). All sampling is then done using Stan (Stan Development Team 2021), and output is returned as a stanfit object.
Why might you want to pass an sdmTMB model to Stan?
- to obtain probabilistic inference on parameters
- to avoid the Laplace approximation on the random effects
- to robustly quantify uncertainty on derived quantities not already calculated in the model
- in some cases, models that struggle to converge with maximum likelihood can be adequately sampled with MCMC given carefully chosen priors (e.g., Monnahan et al. 2021)
Simulating data
Here we will demonstrate using a simulated dataset.
set.seed(123)
predictor_dat <- data.frame(
X = runif(500), Y = runif(500),
a1 = rnorm(500)
)
mesh <- make_mesh(predictor_dat, xy_cols = c("X", "Y"), cutoff = 0.1)
# plot(mesh)
# mesh$mesh$n
sim_dat <- sdmTMB_simulate(
formula = ~a1,
data = predictor_dat,
mesh = mesh,
family = gaussian(),
range = 0.3,
phi = 0.2,
sigma_O = 0.2,
seed = 123,
B = c(0.8, -0.4) # B0 = intercept, B1 = a1 slope
)
Visualize our simulated data:
ggplot(sim_dat, aes(X, Y, colour = observed)) +
geom_point() +
scale_color_viridis_c()
Fitting the model with marginal likelihood
First, fit a spatial random field GLMM with maximum likelihood:
fit <- sdmTMB(
observed ~ a1,
data = sim_dat,
mesh = mesh,
family = gaussian(),
spatial = "on"
)
fit
#> Spatial model fit by ML ['sdmTMB']
#> Formula: observed ~ a1
#> Mesh: mesh (isotropic covariance)
#> Data: sim_dat
#> Family: gaussian(link = 'identity')
#>
#> coef.est coef.se
#> (Intercept) 0.8 0.06
#> a1 -0.4 0.01
#>
#> Dispersion parameter: 0.20
#> Matérn range: 0.32
#> Spatial SD: 0.17
#> ML criterion at convergence: -65.939
#>
#> See ?tidy.sdmTMB to extract these values as a data frame.
Adding priors
In that first model fit we did not use any priors. In that case, the priors are implied as uniform on the internal parameter space. However, sdmTMB provides the option of applying priors. Here we will show an example of applying a Normal(0, 5) (mean, SD) prior on the intercept and a Normal(0, 2) prior on the slope parameter. We could guess at the model matrix structure based on our formula, but we can verify it by looking at the internal model matrix from the previous fit (using do_fit = FALSE would save time if you didn’t want to fit it the first time).
head(fit$tmb_data$X_ij[[1]])
#> (Intercept) a1
#> 1 1 -0.60189285
#> 2 1 -0.99369859
#> 3 1 1.02678506
#> 4 1 0.75106130
#> 5 1 -1.50916654
#> 6 1 -0.09514745
The elements of the b prior vectors follow the order of these columns (intercept first, then a1):
fit <- sdmTMB(
observed ~ a1,
data = sim_dat,
mesh = mesh,
family = gaussian(),
spatial = "on",
priors = sdmTMBpriors(
# location = vector of means; scale = vector of standard deviations:
b = normal(location = c(0, 0), scale = c(5, 2))
)
)
fit
#> Spatial model fit by ML ['sdmTMB']
#> Formula: observed ~ a1
#> Mesh: mesh (isotropic covariance)
#> Data: sim_dat
#> Family: gaussian(link = 'identity')
#>
#> coef.est coef.se
#> (Intercept) 0.8 0.06
#> a1 -0.4 0.01
#>
#> Dispersion parameter: 0.20
#> Matérn range: 0.32
#> Spatial SD: 0.17
#> ML criterion at convergence: -62.846
#>
#> See ?tidy.sdmTMB to extract these values as a data frame.
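To see why the prior vectors line up with the model matrix, here is a base R sketch that uses a small hypothetical stand-in data frame rather than sim_dat. Calling model.matrix() on the same one-sided formula yields the columns (Intercept) and a1 in that order. The first element of location and scale therefore applies to the intercept and the second applies to the a1 slope. The last lines evaluate the Normal log prior density at the estimates printed above. Conceptually, a prior contributes a log-density term like this to the penalized objective, but this is an illustration, not sdmTMB's internal code.

```r
# Hypothetical stand-in data with the same covariate name as sim_dat
set.seed(1)
d <- data.frame(a1 = rnorm(5))

# Columns of the fixed-effect model matrix, in the order the b prior follows
X <- model.matrix(~a1, data = d)
colnames(X) # "(Intercept)" "a1"

# Prior means and SDs, one per column, matching the normal() call above
prior_loc <- c(0, 0)
prior_sd <- c(5, 2)

# Normal log prior density evaluated at the printed estimates (0.8, -0.4)
b <- c(0.8, -0.4)
log_prior <- sum(dnorm(b, mean = prior_loc, sd = prior_sd, log = TRUE))
log_prior # about -4.17
```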
Fixing a spatial correlation parameter to improve convergence
Sometimes some of the spatial correlation parameters can be challenging to estimate with Stan. One option is to apply penalized complexity (PC) priors to the Matérn parameters with sdmTMBpriors(). Another option is to fix one or more parameters at their maximum likelihood estimate (MLE) values, and this can also be combined with the priors. Fixing the parameter ln_kappa frequently helps convergence (e.g., Monnahan et al. 2021). This estimated parameter is transformed into the range estimate, so it controls the rate at which spatial correlation decays with distance.
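A minimal base R sketch of that transformation follows. It assumes sdmTMB reports the Matérn range as sqrt(8) / κ with κ = exp(ln_kappa), which is the usual range definition for the SPDE approach with smoothness ν = 1. With the ln_kappa value of 2.173 reported in this vignette, the sketch reproduces the printed range of about 0.32. The helper name matern_range() is hypothetical.

```r
# Hypothetical helper: Matérn range from the log of the SPDE kappa parameter,
# assuming range = sqrt(8) / kappa
matern_range <- function(ln_kappa) sqrt(8) / exp(ln_kappa)

# ln_kappa = 2.173 gives a range of about 0.32, as in the model printout
matern_range(2.173)

# Larger ln_kappa means a shorter range, i.e., faster decay of spatial correlation
matern_range(c(1.5, 2.173, 3))
```

Fixing ln_kappa therefore fixes the range, and Stan samples only the remaining parameters.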
Now we will rebuild the fitted object with fixed (‘mapped’) ln_kappa parameters using the update() function. We’ll use do_fit = FALSE to avoid actually fitting the updated model, since it’s not necessary.
# grab the internal parameter list at estimated values:
pars <- sdmTMB::get_pars(fit)
# create a 'map' vector for TMB
# factor NA values cause TMB to fix or map the parameter at the starting value:
kappa_map <- factor(rep(NA, length(pars$ln_kappa)))
# rebuild model updating some elements:
fit_mle <- update(
fit,
control = sdmTMBcontrol(
start = list(
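ln_kappa = pars$ln_kappa # start at the estimated values
),
map = list(
ln_kappa = kappa_map # NA factor levels: keep ln_kappa fixed at the start values
)
),
do_fit = FALSE # only rebuild the TMB object; do not optimize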
)
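The map argument follows TMB's convention for factors, which has three cases:

- An NA element fixes the corresponding parameter at its starting value.
- Elements sharing a factor level are estimated as a single shared (mirrored) value.
- Distinct levels are estimated separately.

Here is a small base R sketch of the three cases for a two-element ln_kappa. The object names are hypothetical.

```r
# Both ln_kappa elements fixed at their starting values (as in kappa_map above)
map_fixed <- factor(rep(NA, 2))
# Both elements mirrored: estimated as one shared value
map_shared <- factor(c(1, 1))
# Both elements estimated separately
map_free <- factor(c(1, 2))

# Number of free parameters implied by each map (one per factor level)
n_free <- sapply(list(map_fixed, map_shared, map_free), nlevels)
n_free # 0 1 2
```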
#> ℹ Initiating `ln_kappa` at specified starting value(s) of:
#> 2.173, 2.173
#> ℹ Fixing or mirroring `ln_kappa`
Passing the model to tmbstan
Now we can pass the $tmb_obj element of our model to tmbstan::tmbstan(). We are only using 1000 iterations and 2 chains so this vignette builds quickly. In practice, you will likely want to use more (e.g., 2000 iterations, 4 chains).
fit_stan <- tmbstan::tmbstan(
fit_mle$tmb_obj,
iter = 1000, chains = 2,
seed = 8217 # ensures repeatability
)
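Sometimes you may need to adjust the sampler settings, such as adapt_delta or max_treedepth. These can be supplied as a control list in the call to tmbstan::tmbstan(), which passes them on to Stan. See the Details section in ?rstan::stan.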
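Here is a sketch of sampler settings that are commonly adjusted. The values are illustrative and not tuned for this model.

- adapt_delta (rstan default 0.8) is the target acceptance rate. Raising it makes the sampler take smaller steps, which can reduce divergent transitions at the cost of speed.
- max_treedepth (default 10) caps the number of NUTS tree doublings per iteration.

The object name stan_control is hypothetical.

```r
# Illustrative (not tuned) NUTS settings; rstan defaults are
# adapt_delta = 0.8 and max_treedepth = 10
stan_control <- list(
  adapt_delta = 0.9,  # higher target acceptance rate -> smaller step size
  max_treedepth = 12  # allow longer trajectories per iteration
)

# These would be passed through tmbstan::tmbstan() to Stan, e.g.:
# fit_stan <- tmbstan::tmbstan(fit_mle$tmb_obj, iter = 1000, chains = 2,
#   seed = 8217, control = stan_control)
str(stan_control)
```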
You can also ‘thin’ samples via the thin argument if working with model predictions becomes cumbersome given a large number of required samples.
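As a rough guide to how thinning changes the number of saved draws, here is a small hypothetical helper. It assumes rstan's default warmup of floor(iter / 2) and that every thin-th post-warmup iteration is kept. With the settings used above, it reproduces the 1000 total post-warmup draws reported in the fit_stan printout.

```r
# Hypothetical helper: total post-warmup draws saved across all chains
total_post_warmup <- function(iter, chains, thin = 1, warmup = floor(iter / 2)) {
  chains * ceiling((iter - warmup) / thin)
}

total_post_warmup(iter = 1000, chains = 2)            # 1000
total_post_warmup(iter = 1000, chains = 2, thin = 5)  # 200
total_post_warmup(iter = 2000, chains = 4, thin = 5)  # 800
```

Thinning reduces memory use and the cost of steps that scale with the number of draws, such as predict(), but it also discards draws. It is therefore mainly useful when those downstream costs dominate.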
We can look at the model:
fit_stan
#> Inference for Stan model: sdmTMB.
#> 2 chains, each with iter=1000; warmup=500; thin=1;
#> post-warmup draws per chain=500, total post-warmup draws=1000.
#>
#> mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
#> b_j[1] 0.80 0.01 0.06 0.68 0.76 0.80 0.84 0.92 82 1.02
#> b_j[2] -0.40 0.00 0.01 -0.41 -0.40 -0.40 -0.39 -0.38 1747 1.00
#> ln_tau_O -1.67 0.01 0.14 -1.93 -1.76 -1.67 -1.57 -1.42 169 1.01
#> ln_phi -1.63 0.00 0.03 -1.69 -1.65 -1.63 -1.61 -1.56 1360 1.00
#> omega_s[1] -0.09 0.01 0.09 -0.26 -0.14 -0.09 -0.03 0.09 162 1.01
#> omega_s[2] -0.06 0.01 0.09 -0.25 -0.12 -0.06 0.01 0.11 196 1.01
#> omega_s[3] 0.01 0.01 0.10 -0.17 -0.05 0.01 0.08 0.20 196 1.01
#> omega_s[4] -0.21 0.01 0.09 -0.40 -0.27 -0.21 -0.15 -0.02 206 1.01
#> omega_s[5] -0.34 0.01 0.10 -0.53 -0.41 -0.33 -0.27 -0.13 247 1.00
#> omega_s[6] -0.08 0.01 0.10 -0.27 -0.15 -0.08 0.00 0.13 268 1.01
#> omega_s[7] -0.02 0.01 0.09 -0.19 -0.08 -0.02 0.04 0.15 172 1.01
#> omega_s[8] -0.23 0.01 0.10 -0.42 -0.29 -0.23 -0.16 -0.05 219 1.01
#> omega_s[9] -0.32 0.01 0.09 -0.50 -0.38 -0.32 -0.26 -0.13 166 1.02
#> omega_s[10] 0.29 0.01 0.09 0.11 0.23 0.28 0.35 0.46 147 1.01
#> omega_s[11] -0.16 0.01 0.09 -0.34 -0.22 -0.16 -0.09 0.02 208 1.01
#> omega_s[12] 0.00 0.01 0.10 -0.19 -0.07 0.01 0.07 0.21 205 1.01
#> omega_s[13] 0.19 0.01 0.09 0.02 0.13 0.19 0.25 0.36 166 1.01
#> omega_s[14] -0.09 0.01 0.10 -0.29 -0.16 -0.09 -0.03 0.11 253 1.01
#> omega_s[15] 0.22 0.01 0.09 0.06 0.16 0.22 0.28 0.40 152 1.01
#> omega_s[16] -0.02 0.01 0.10 -0.21 -0.08 -0.02 0.05 0.16 198 1.01
#> omega_s[17] -0.15 0.01 0.09 -0.34 -0.21 -0.14 -0.08 0.04 198 1.01
#> omega_s[18] -0.28 0.01 0.11 -0.50 -0.35 -0.28 -0.21 -0.09 270 1.00
#> omega_s[19] 0.00 0.01 0.10 -0.18 -0.06 0.00 0.07 0.19 211 1.01
#> omega_s[20] 0.03 0.01 0.08 -0.14 -0.03 0.02 0.09 0.19 180 1.01
#> omega_s[21] 0.08 0.01 0.10 -0.11 0.02 0.08 0.15 0.26 191 1.01
#> omega_s[22] -0.01 0.01 0.10 -0.21 -0.08 -0.01 0.05 0.17 212 1.00
#> omega_s[23] 0.12 0.01 0.09 -0.04 0.06 0.12 0.18 0.30 156 1.01
#> omega_s[24] 0.20 0.01 0.10 0.00 0.13 0.19 0.26 0.40 256 1.00
#> omega_s[25] 0.08 0.01 0.09 -0.09 0.01 0.07 0.14 0.26 167 1.01
#> omega_s[26] -0.01 0.01 0.10 -0.21 -0.08 -0.01 0.06 0.20 215 1.01
#> omega_s[27] -0.11 0.01 0.09 -0.30 -0.18 -0.10 -0.05 0.06 216 1.00
#> omega_s[28] 0.12 0.01 0.10 -0.08 0.05 0.11 0.18 0.32 229 1.00
#> omega_s[29] 0.30 0.01 0.09 0.12 0.23 0.30 0.36 0.47 212 1.01
#> omega_s[30] -0.04 0.01 0.09 -0.22 -0.10 -0.04 0.02 0.14 193 1.01
#> omega_s[31] 0.09 0.01 0.09 -0.08 0.04 0.09 0.15 0.26 154 1.01
#> omega_s[32] 0.05 0.01 0.12 -0.18 -0.02 0.05 0.14 0.28 244 1.01
#> omega_s[33] 0.08 0.01 0.09 -0.10 0.01 0.07 0.14 0.26 219 1.01
#> omega_s[34] 0.04 0.01 0.09 -0.13 -0.03 0.04 0.10 0.20 163 1.01
#> omega_s[35] 0.07 0.01 0.10 -0.10 0.00 0.07 0.14 0.26 230 1.01
#> omega_s[36] 0.14 0.01 0.09 -0.04 0.08 0.15 0.20 0.33 197 1.01
#> omega_s[37] 0.16 0.01 0.11 -0.06 0.09 0.16 0.24 0.39 275 1.00
#> omega_s[38] 0.12 0.01 0.10 -0.09 0.05 0.11 0.19 0.31 217 1.01
#> omega_s[39] -0.22 0.01 0.10 -0.40 -0.28 -0.22 -0.15 -0.04 186 1.01
#> omega_s[40] -0.03 0.01 0.09 -0.21 -0.09 -0.03 0.03 0.17 197 1.01
#> omega_s[41] 0.19 0.01 0.08 0.03 0.13 0.19 0.24 0.35 143 1.01
#> omega_s[42] 0.21 0.01 0.09 0.01 0.14 0.21 0.27 0.38 173 1.01
#> omega_s[43] 0.15 0.01 0.10 -0.04 0.07 0.15 0.21 0.35 221 1.01
#> omega_s[44] 0.14 0.01 0.10 -0.07 0.08 0.14 0.20 0.32 194 1.01
#> omega_s[45] 0.10 0.01 0.10 -0.10 0.02 0.10 0.16 0.30 248 1.00
#> omega_s[46] 0.06 0.01 0.10 -0.14 -0.01 0.06 0.12 0.26 187 1.01
#> omega_s[47] 0.31 0.01 0.09 0.13 0.24 0.31 0.37 0.50 166 1.01
#> omega_s[48] -0.24 0.01 0.10 -0.44 -0.31 -0.24 -0.18 -0.06 196 1.01
#> omega_s[49] 0.10 0.01 0.10 -0.10 0.04 0.10 0.17 0.32 255 1.01
#> omega_s[50] -0.09 0.01 0.09 -0.26 -0.14 -0.09 -0.03 0.08 178 1.01
#> omega_s[51] 0.25 0.01 0.11 0.04 0.17 0.25 0.32 0.48 247 1.01
#> omega_s[52] -0.21 0.01 0.11 -0.43 -0.29 -0.21 -0.14 -0.01 251 1.01
#> omega_s[53] 0.04 0.01 0.10 -0.16 -0.03 0.04 0.10 0.22 213 1.00
#> omega_s[54] 0.03 0.01 0.10 -0.16 -0.04 0.03 0.10 0.22 214 1.00
#> omega_s[55] -0.09 0.01 0.11 -0.30 -0.16 -0.09 -0.01 0.15 311 1.01
#> omega_s[56] -0.42 0.01 0.10 -0.62 -0.48 -0.41 -0.35 -0.24 229 1.01
#> omega_s[57] 0.01 0.01 0.11 -0.20 -0.07 0.00 0.08 0.24 274 1.00
#> omega_s[58] -0.21 0.01 0.10 -0.40 -0.28 -0.21 -0.14 -0.01 201 1.01
#> omega_s[59] 0.04 0.01 0.24 -0.43 -0.11 0.05 0.20 0.47 992 1.00
#> omega_s[60] -0.23 0.01 0.28 -0.78 -0.41 -0.23 -0.03 0.27 1087 1.00
#> omega_s[61] -0.26 0.01 0.24 -0.69 -0.41 -0.27 -0.10 0.20 662 1.00
#> omega_s[62] -0.26 0.01 0.23 -0.72 -0.42 -0.26 -0.11 0.19 693 1.00
#> omega_s[63] -0.27 0.01 0.23 -0.72 -0.43 -0.28 -0.12 0.18 807 1.00
#> omega_s[64] 0.08 0.01 0.22 -0.37 -0.07 0.09 0.23 0.51 866 1.00
#> omega_s[65] 0.17 0.01 0.22 -0.29 0.02 0.17 0.32 0.57 693 1.00
#> omega_s[66] 0.16 0.01 0.22 -0.27 0.02 0.16 0.31 0.60 525 1.00
#> omega_s[67] -0.03 0.01 0.22 -0.49 -0.17 -0.01 0.13 0.42 1154 1.00
#> omega_s[68] -0.02 0.01 0.25 -0.50 -0.17 -0.02 0.15 0.47 1152 1.00
#> omega_s[69] 0.00 0.01 0.22 -0.46 -0.15 0.00 0.15 0.43 458 1.01
#> omega_s[70] 0.01 0.01 0.21 -0.42 -0.13 0.01 0.15 0.40 431 1.00
#> omega_s[71] 0.17 0.01 0.24 -0.30 0.01 0.17 0.33 0.67 884 1.00
#> omega_s[72] -0.12 0.01 0.24 -0.62 -0.29 -0.11 0.04 0.35 1064 1.00
#> omega_s[73] -0.14 0.01 0.22 -0.58 -0.28 -0.15 0.01 0.30 600 1.00
#> omega_s[74] -0.14 0.01 0.22 -0.56 -0.29 -0.14 0.01 0.26 624 1.00
#> omega_s[75] -0.38 0.01 0.20 -0.76 -0.51 -0.38 -0.25 0.00 654 1.00
#> omega_s[76] 0.10 0.01 0.21 -0.31 -0.04 0.10 0.23 0.51 848 1.00
#> omega_s[77] 0.09 0.01 0.22 -0.34 -0.04 0.09 0.23 0.52 1053 1.00
#> omega_s[78] -0.06 0.01 0.21 -0.47 -0.19 -0.06 0.08 0.35 1693 1.00
#> omega_s[79] 0.06 0.01 0.16 -0.27 -0.05 0.06 0.17 0.37 622 1.00
#> omega_s[80] -0.18 0.01 0.23 -0.65 -0.33 -0.18 -0.02 0.27 632 1.00
#> omega_s[81] -0.07 0.01 0.20 -0.45 -0.20 -0.07 0.05 0.33 1165 1.00
#> omega_s[82] -0.14 0.01 0.21 -0.55 -0.28 -0.14 0.00 0.29 810 1.00
#> omega_s[83] -0.02 0.01 0.23 -0.47 -0.18 -0.01 0.14 0.44 1096 1.00
#> omega_s[84] 0.12 0.01 0.20 -0.26 -0.02 0.12 0.25 0.53 952 1.00
#> omega_s[85] -0.29 0.01 0.21 -0.68 -0.44 -0.30 -0.15 0.11 1111 1.00
#> lp__ 136.06 0.93 9.26 118.11 129.87 136.38 142.46 152.88 100 1.02
#>
#> Samples were drawn using NUTS(diag_e) at Wed Jul 3 18:50:49 2024.
#> For each parameter, n_eff is a crude measure of effective sample size,
#> and Rhat is the potential scale reduction factor on split chains (at
#> convergence, Rhat=1).
The Rhat values look reasonable (< 1.05). The n_eff (number of effective samples) values look reasonable (> 100) for inference about the mean of every parameter except the intercept (b_j[1]). Its low n_eff reflects strong autocorrelation in the MCMC samples for b_j[1]. We could try running for more iterations and chains and/or placing priors on this and other parameters as described above (highly recommended).
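The rules of thumb above can be applied programmatically. The sketch below copies a few rows of n_eff and Rhat from the printout into a data frame and flags parameters that fail either threshold. With a fitted model, the same "n_eff" and "Rhat" columns are available in the matrix returned by summary(fit_stan)$summary.

```r
# A few diagnostics copied from the fit_stan printout above
diag_tab <- data.frame(
  par = c("b_j[1]", "b_j[2]", "ln_tau_O", "ln_phi"),
  n_eff = c(82, 1747, 169, 1360),
  Rhat = c(1.02, 1.00, 1.01, 1.00),
  stringsAsFactors = FALSE
)

# Flag parameters that do not meet the thresholds used in the text
flagged <- subset(diag_tab, Rhat >= 1.05 | n_eff <= 100)
flagged$par # "b_j[1]"
```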
Now we can use various functions to visualize the posterior:
plot(fit_stan)
#> 'pars' not specified. Showing first 10 parameters by default.
#> ci_level: 0.8 (80% intervals)
#> outer_level: 0.95 (95% intervals)
pars_plot <- c("b_j[1]", "b_j[2]", "ln_tau_O", "omega_s[1]")
bayesplot::mcmc_trace(fit_stan, pars = pars_plot)
bayesplot::mcmc_pairs(fit_stan, pars = pars_plot)
Posterior predictive checks
We can perform posterior predictive checks to assess whether our model can generate predictive data that are consistent with the observations. For this, we can use simulate.sdmTMB() and pass in the posterior samples from our Stan model via the mcmc_samples argument. simulate.sdmTMB() will take draws from the joint parameter posterior and add observation error. We need to ensure that nsim is less than or equal to the total number of post-warmup samples.
set.seed(19292)
samps <- sdmTMBextra::extract_mcmc(fit_stan)
s <- simulate(fit_mle, mcmc_samples = samps, nsim = 50)
bayesplot::pp_check(
sim_dat$observed,
yrep = t(s),
fun = bayesplot::ppc_dens_overlay
)
See ?bayesplot::pp_check. The solid line represents the density of the observed data and the light blue lines represent the density of 50 posterior predictive simulations. In this case, the simulated data seem consistent with the observed data.
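Here is a conceptual base R sketch of a Gaussian posterior predictive simulation. It is not simulate.sdmTMB()'s actual implementation, and the posterior draws are hypothetical stand-ins generated with rnorm() rather than taken from fit_stan. For each selected draw, the sketch adds observation error with that draw's SD (phi) to that draw's mean. The result has one row per observation and one column per simulation. bayesplot expects yrep with draws in rows, which is why the code above transposes s.

```r
set.seed(1)
n_obs <- 500
n_draws <- 1000 # total post-warmup draws
nsim <- 50
stopifnot(nsim <= n_draws) # cannot request more simulations than draws

# Hypothetical stand-ins for posterior draws (NOT from fit_stan):
# one column of linear-predictor values (identity link) per draw,
# and one observation SD (phi) per draw
mu_draws <- matrix(rnorm(n_obs * n_draws, mean = 0.8, sd = 0.1), n_obs, n_draws)
phi_draws <- exp(rnorm(n_draws, mean = -1.63, sd = 0.03))

# Pick nsim draws and add Gaussian observation error to each
idx <- sample(n_draws, nsim)
s_demo <- sapply(idx, function(j) rnorm(n_obs, mean = mu_draws[, j], sd = phi_draws[j]))

dim(s_demo)    # 500 50: observations x simulations
dim(t(s_demo)) # 50 500: simulations x observations, the yrep layout
```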
Plotting predictions
We can make predictions with our Bayesian model by supplying the posterior samples to the mcmc_samples argument in predict.sdmTMB().
pred <- predict(fit_mle, mcmc_samples = samps)
The output is a matrix where each row corresponds to a row of predicted data and each column corresponds to a sample.
dim(pred)
#> [1] 500 1000
We can summarize these draws in various ways to visualize them:
sim_dat$post_mean <- apply(pred, 1, mean)
sim_dat$post_sd <- apply(pred, 1, sd)
ggplot(sim_dat, aes(X, Y, colour = post_mean)) +
geom_point() +
scale_color_viridis_c()
ggplot(sim_dat, aes(X, Y, colour = post_sd)) +
geom_point() +
scale_color_viridis_c()
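Beyond the mean and SD, row-wise quantiles give pointwise credible intervals. The base R sketch below uses a hypothetical stand-in matrix with the same layout as pred (rows = data points, columns = posterior draws). With the real object, replace pred_demo with pred. Because these predictions are in link space, a non-identity link would require applying the inverse link to the draws (e.g., exp() for a log link) before summarizing. With the Gaussian identity link used here, that step is not needed.

```r
set.seed(42)
# Hypothetical stand-in with the same layout as `pred`: 500 rows x 1000 draws
pred_demo <- matrix(rnorm(500 * 1000), nrow = 500, ncol = 1000)

# 95% equal-tailed credible interval for each row (each prediction)
ci <- t(apply(pred_demo, 1, quantile, probs = c(0.025, 0.975)))
dim(ci)      # 500 2
colnames(ci) # "2.5%" "97.5%"
```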
Or predict on a grid for a given value of a1:
nd <- expand.grid(
X = seq(0, 1, length.out = 70),
Y = seq(0, 1, length.out = 70),
a1 = 0
)
pred <- predict(fit_mle, newdata = nd, mcmc_samples = samps)
nd$post_mean <- apply(pred, 1, mean)
nd$post_sd <- apply(pred, 1, sd)
ggplot(nd, aes(X, Y, fill = post_mean)) +
geom_raster() +
scale_fill_viridis_c() +
coord_fixed()
ggplot(nd, aes(X, Y, fill = post_sd)) +
geom_raster() +
scale_fill_viridis_c() +
coord_fixed()
Extracting parameter posterior samples
We can extract posterior samples with rstan::extract():
post <- rstan::extract(fit_stan)
The result is a list where each element corresponds to a parameter or set of parameters:
names(post)
#> [1] "b_j" "ln_tau_O" "ln_phi" "omega_s" "lp__"
hist(post$b_j[, 1])
As an example of calculating a derived parameter, here we will calculate the marginal spatial random field standard deviation:
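A minimal sketch of that calculation, assuming sdmTMB's SPDE parameterization in which the marginal spatial SD is σ_O = 1 / sqrt(4π exp(2 ln τ_O + 2 ln κ)). Because ln_kappa was fixed (mapped) before sampling, it does not appear in names(post). Its value therefore comes from the MLE and is combined with each posterior draw of ln_tau_O. The helper name marginal_sd_O() is hypothetical.

```r
# Hypothetical helper: marginal SD of the spatial random field, assuming
# sigma_O = 1 / sqrt(4 * pi * exp(2 * ln_tau_O + 2 * ln_kappa))
marginal_sd_O <- function(ln_tau_O, ln_kappa) {
  1 / sqrt(4 * pi * exp(2 * ln_tau_O + 2 * ln_kappa))
}

# Check against the printout: ln_kappa = 2.173 (the fixed value reported above)
# and ln_tau_O at its posterior mean of -1.67 give about 0.17 ("Spatial SD: 0.17")
marginal_sd_O(ln_tau_O = -1.67, ln_kappa = 2.173)
```

With the objects from this vignette, the posterior would then be sigma_O <- marginal_sd_O(post$ln_tau_O, pars$ln_kappa[1]) followed by hist(sigma_O). This assumes the first ln_kappa element corresponds to the spatial field.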
Extracting the posterior of other predicted elements
By default, predict.sdmTMB() returns the overall prediction in link space when a tmbstan model is passed in. We may instead want some other element from the usual data frame that predict.sdmTMB() returns for a regular sdmTMB model. In that case, we can specify it through the sims_var argument.
For example, let’s extract the spatial random field values, "omega_s". Other options are documented in ?predict.sdmTMB.
fit_pred <- predict(
fit_mle,
newdata = nd,
mcmc_samples = samps,
sims_var = "omega_s" # return spatial random field draws instead of the overall prediction
)
nd$spatial_rf_mean <- apply(fit_pred, 1, mean)
nd$spatial_rf_sd <- apply(fit_pred, 1, sd)
ggplot(nd, aes(X, Y, fill = spatial_rf_mean)) +
geom_raster() +
scale_fill_gradient2() +
coord_fixed()
ggplot(nd, aes(X, Y, fill = spatial_rf_sd)) +
geom_raster() +
scale_fill_viridis_c() +
coord_fixed()