March 22, 2022

Fitting a Basic SIR Model in Stan

Eric Novik
CEO, Co-Founder

Today, it seems like everyone is an epidemiologist. I am definitely not one, but I did want to learn the basics of the popular SIR (Susceptible, Infected, Recovered) models. My guide is “Contemporary statistical inference for infectious disease models using Stan” by Chatzilena et al. [1], which does a great job of describing the model and, in the spirit of reproducibility, provides both Stan and R code. In this blog post, I will test a slightly modified version of the model on a simulated dataset and implement out-of-sample prediction logic in Stan.

SIR model

The SIR model can be formulated using a three-state ODE system as follows.

$$\begin{align*}
\frac{\text{d}S}{\text{d}t} &= -\beta \frac{I(t)}{N} S(t) \\
\frac{\text{d}I}{\text{d}t} &= \beta \frac{I(t)}{N} S(t) - \gamma I(t) \\
\frac{\text{d}R}{\text{d}t} &= \gamma I(t)
\end{align*}$$

$S(t)$ is the number of susceptible individuals, $I(t)$ the number of infected individuals, and $R(t)$ the number of recovered individuals. $N$ is the total population, which is the sum of the three groups; $\beta$ represents the transmission rate and $\gamma$ the recovery rate; and the by now infamous $R_0$, the basic reproduction number, is the ratio $\beta / \gamma$. In this model, the infection can start spreading only when $R_0 > 1$: at the start of an outbreak, when $S(t) \approx N$, the number of infected individuals grows exactly when $\beta > \gamma$.
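A short base-R sketch makes the threshold concrete, using the parameter values adopted later in the simulation ($\beta = 2$, $\gamma = 0.5$). It relies only on the equations above: while $S(t) \approx N$, $\text{d}I/\text{d}t \approx (\beta - \gamma) I(t)$, and because $\text{d}I/\text{d}t = I(t)\,(\beta S(t)/N - \gamma)$, the infected curve peaks when $S(t)/N$ falls to $\gamma / \beta = 1/R_0$. The variable names are our own.

```r
# Parameter values used in the simulation later in this post
beta  <- 2    # transmission rate (per day)
gamma <- 0.5  # recovery rate (per day)

# Basic reproduction number
R0 <- beta / gamma

# While S(t) is close to N, dI/dt is approximately (beta - gamma) * I(t),
# so infections grow exponentially at this per-day rate when R0 > 1
early_growth_rate <- beta - gamma
doubling_time <- log(2) / early_growth_rate  # days

# dI/dt = I(t) * (beta * S(t) / N - gamma) is zero when S(t) / N = 1 / R0,
# so the infected curve peaks once 75% of the population is no longer susceptible
susceptible_fraction_at_peak <- gamma / beta

round(c(R0 = R0,
        early_growth_rate = early_growth_rate,
        doubling_time = doubling_time,
        susceptible_fraction_at_peak = susceptible_fraction_at_peak), 3)
```

These back-of-the-envelope numbers are useful for sanity-checking the simulated curves: early on, infections should roughly double every half day, and the infected curve should turn over when about a quarter of the school is still susceptible.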

If you are a pharmacometrician, this model will have a familiar multi-compartment flavor: people move from the susceptible compartment to the infected compartment as they catch the disease, at rate $\beta S(t) I(t) / N$, and from the infected compartment to the recovered compartment as they recover, at rate $\gamma I(t)$.

Setting up a simulation in R

It is easy to set up a forward simulation from this model in R thanks to the deSolve package, which implements a number of ODE solvers.

```r
library(ggplot2)
library(deSolve)
library(tidyr)
theme_set(theme_minimal())

# S: susceptible, I: infected, R: recovered
# b: transmission rate (beta), g: recovery rate (gamma)
# Returns the derivatives as a list, the format deSolve::ode() expects
SIR <- function(t, state, pars) {
  with(as.list(c(state, pars)), {
    N <- S + I + R
    dS_dt <- -b * I/N * S
    dI_dt <-  b * I/N * S - g * I
    dR_dt <-  g * I
    return(list(c(dS_dt, dI_dt, dR_dt)))
  })
}
```
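Because people leave one compartment only by entering another, the three derivatives sum to zero and the total population $N$ stays constant over time. That gives a cheap sanity check for a right-hand-side function like SIR(): evaluate it once at some state, without calling the solver, and confirm that the derivatives cancel. The snippet repeats SIR() so it runs on its own; the state and parameter values are illustrative.

```r
# Same right-hand side as SIR() above, repeated so this snippet is standalone
SIR <- function(t, state, pars) {
  with(as.list(c(state, pars)), {
    N <- S + I + R
    dS_dt <- -b * I/N * S
    dI_dt <-  b * I/N * S - g * I
    dR_dt <-  g * I
    return(list(c(dS_dt, dI_dt, dR_dt)))
  })
}

# Illustrative state partway through an outbreak (values chosen for the example)
state <- c(S = 500, I = 200, R = 63)
pars  <- c(b = 2, g = 0.5)

derivs <- SIR(t = 0, state = state, pars = pars)[[1]]
derivs
sum(derivs)  # zero up to floating-point rounding: N does not change over time
```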

Once the ODE system is set up, we pick the values for the parameters, initial conditions, and time points, and call the solver. I took the population size and the estimated parameter values from the paper, which describes the outbreak as follows:

In 1978, there was a report to the British Medical Journal for an influenza outbreak in a boarding school in the north of England. There were 763 male students which were mostly full boarders and 512 of them became ill.

```r
n <- 100                                 # number of time points
n_pop <- 763                             # boarding-school population
pars <- c(b = 2, g = 0.5)                # beta = 2, gamma = 0.5, so R0 = 4
state <- c(S = n_pop - 1, I = 1, R = 0)  # one initial infection
times <- seq(1, 15, length.out = n)
sol <-
  ode(
    y = state,
    times = times,
    func = SIR,
    parms = pars,
    method = "ode45"
  )
sol <- as.data.frame(sol)
sol_long <- sol |>
  tidyr::pivot_longer(-time, names_to = "state", values_to = "value")

sol_long |>
  ggplot(aes(time, value, color = state)) +
  geom_line() +
  guides(color = guide_legend(title = NULL)) +
  scale_color_discrete(labels = c("Infected", "Recovered", "Susceptible")) +
  xlab("Days") + ylab("Number of people") +
  ggtitle("Basic SIR model",
          subtitle = "3-state ODE with beta = 2, gamma = 0.5, R0 = 4")
```

If you run the above code, you should see the following plot:

Inferring the parameters in Stan

Forward simulations are easy and fun, but science works in the other direction: we observe noisy data, which in our case are diagnosed infections, and from these data and the stated model we must learn all the plausible values of the parameters. To simulate the noisy data-generating process, we will assume that our observed infections y are Poisson distributed (not a great assumption, but it will do for our purposes) with rate $\lambda(t)$ equal to $I(t)$ from the solver.

```r
set.seed(1234)
y <- rpois(n, sol$I)
```
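In the Stan model, this observation model appears as y ~ poisson(lambda). Its contribution to the log density is the Poisson log-likelihood $\sum_t \left[y_t \log\lambda_t - \lambda_t - \log y_t!\right]$; with a ~ statement, Stan drops the $\log y_t!$ term because it does not depend on the parameters. The base-R sketch below makes that explicit for a handful of illustrative counts and rates (y_obs and rate are made-up values, not the simulated data).

```r
# Illustrative observed counts and Poisson rates (not from the simulation)
y_obs <- c(3, 10, 25, 40)
rate  <- c(2.5, 12, 22, 45)

# Full Poisson log-likelihood, as returned by dpois()
loglik_full <- sum(dpois(y_obs, rate, log = TRUE))

# Same quantity written out term by term: y * log(lambda) - lambda - log(y!)
loglik_manual <- sum(y_obs * log(rate) - rate - lgamma(y_obs + 1))

# Dropping log(y!) leaves the parameter-dependent part kept by a `~` statement;
# it differs from the full value by a constant that does not affect sampling
loglik_kernel <- sum(y_obs * log(rate) - rate)

c(full = loglik_full, manual = loglik_manual, kernel = loglik_kernel)
```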

The following Stan code is a slightly modified version of the one found in the paper. I changed the priors a bit, put some additional constraints on the parameters, and made the ODE code a little more readable. You will also notice that, unlike in the forward simulation, the ODE in the Stan program has no $1/N$ term. This is because the Stan program works on the proportion scale: the three states are fractions of the population that sum to one, so effectively $N = 1$. The solution is converted back to the absolute scale by multiplying by n_pop when computing $\lambda(t)$ and later in the make_df() R function. The code listing is available at [2].
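Dividing each equation of the original system by the constant $N$ shows where the $1/N$ term goes. Writing $s(t) = S(t)/N$, $i(t) = I(t)/N$, and $r(t) = R(t)/N$ for the population fractions gives

$$\begin{align*}
\frac{\text{d}s}{\text{d}t} &= -\beta\, s(t)\, i(t) \\
\frac{\text{d}i}{\text{d}t} &= \beta\, s(t)\, i(t) - \gamma\, i(t) \\
\frac{\text{d}r}{\text{d}t} &= \gamma\, i(t)
\end{align*}$$

with $s(t) + i(t) + r(t) = 1$. This is the system coded in the Stan functions block (where S, I, and R hold fractions), and the expected number of infected individuals is $\lambda(t) = N\, i(t)$, which the transformed parameters block computes as y_hat[i, 2] * n_pop.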

```stan
functions {
  vector SIR(real t, vector y, array[] real theta) {
    real S     = y[1];
    real I     = y[2];
    real R     = y[3];

    real beta  = theta[1];
    real gamma = theta[2];

    vector[3] dydt;

    dydt[1] = -beta * S * I;
    dydt[2] =  beta * S * I - gamma * I;
    dydt[3] =  gamma * I;

    return dydt;
  }
}

data {
  int<lower=1> n_obs;   // number of observation times
  int<lower=1> n_pop;   // total population size
  array[n_obs] int y;   // observed infected counts
  real t0;              // initial time (e.g. 0)
  array[n_obs] real ts; // times at which y was observed
}

parameters {
  array[2] real<lower=0> theta; // {beta, gamma}
  real<lower=0,upper=1> S0;     // initial susceptible fraction
}

transformed parameters {
  vector[3] y_init;             // [S(0), I(0), R(0)]
  array[n_obs] vector[3] y_hat; // solution at each ts
  array[n_obs] real lambda;     // Poisson rates

  // set initial conditions
  y_init[1] = S0;
  y_init[2] = 1 - S0;
  y_init[3] = 0;

  y_hat = ode_rk45(SIR, y_init, t0, ts, theta);

  // convert infected fraction → expected counts
  for (i in 1:n_obs)
    lambda[i] = y_hat[i, 2] * n_pop;
}

model {
  theta ~ lognormal(0, 1);
  S0    ~ beta(1, 1);
  y     ~ poisson(lambda);
}

generated quantities {
  real R_0 = theta[1] / theta[2];
}
```
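The generated quantities block reports $R_0$ as theta[1] / theta[2]. Since both rates get independent lognormal(0, 1) priors, $\log R_0 = \log\beta - \log\gamma$ is normal with mean 0 and standard deviation $\sqrt{2}$, so the implied prior on $R_0$ is lognormal(0, $\sqrt{2}$). The base-R sketch below (our own check, not part of the paper's code) computes the implied prior quantiles exactly and confirms them by simulation, a cheap way to see how much the prior alone says about $R_0$.

```r
set.seed(1)

# Prior draws matching the Stan model: beta, gamma ~ lognormal(0, 1), independent
n_draws <- 1e5
beta_prior  <- rlnorm(n_draws, meanlog = 0, sdlog = 1)
gamma_prior <- rlnorm(n_draws, meanlog = 0, sdlog = 1)
R0_prior <- beta_prior / gamma_prior

probs <- c(0.05, 0.5, 0.95)

# Exact: log(R0) ~ normal(0, sqrt(2)), so R0 ~ lognormal(0, sqrt(2))
R0_exact <- qlnorm(probs, meanlog = 0, sdlog = sqrt(2))

# Monte Carlo estimate from the prior draws
R0_mc <- quantile(R0_prior, probs = probs, names = FALSE)

rbind(exact = R0_exact, monte_carlo = R0_mc)

# Prior probability that an outbreak can take off (R0 > 1)
mean(R0_prior > 1)
```

The prior puts half its mass on $R_0 < 1$ and spans roughly 0.1 to 10 in its central 90%, so it is weakly informative relative to the $R_0 = 4$ used in the simulation.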

Once the Stan model is ready, we can construct the input dataset and start sampling. I use the CmdStanR interface to run Stan; it calls CmdStan on your computer, which needs to be installed separately (for example with cmdstanr::install_cmdstan(); installation instructions here).

```r
library(cmdstanr)
library(posterior)

data <- list(n_obs = n, n_pop = n_pop, y = y, t0 = 0, ts = times)
sir_mod <- cmdstan_model("sir.stan")
sir_fit <- sir_mod$sample(
  data = data,
  seed = 1234,
  chains = 4,
  parallel_chains = 4,
  iter_warmup = 500,
  iter_sampling = 500
)

theta <- as_draws_rvars(sir_fit$draws(variables = "theta"))
```
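Stan rejects data that violate the declared constraints, and, as we understand the Stan documentation, the ODE solvers also expect the output times ts to be sorted and later than t0. A small pre-flight check in R can catch such problems before compilation and sampling. check_sir_data() is a hypothetical helper written for illustration; it only mirrors the constraints in the data block above plus that ODE time requirement.

```r
# Hypothetical pre-flight check for the data list passed to sir.stan.
# Mirrors the Stan data block plus the requirement that ODE output
# times are sorted and come after the initial time t0.
check_sir_data <- function(data) {
  with(data, {
    stopifnot(
      n_obs >= 1, n_pop >= 1,
      length(y) == n_obs, length(ts) == n_obs,
      all(y >= 0), all(y == round(y)),  # Poisson counts: non-negative integers
      !is.unsorted(ts),                 # output times in ascending order
      all(ts > t0)                      # every output time after t0
    )
  })
  invisible(TRUE)
}

# Same shapes as in the post: 100 time points between day 1 and day 15
times <- seq(1, 15, length.out = 100)
ok_data <- list(n_obs = 100, n_pop = 763, y = rep(0L, 100), t0 = 0, ts = times)
check_sir_data(ok_data)

# Starting the ODE at t0 = 2 fails the check, since the first ts is 1
bad_data <- modifyList(ok_data, list(t0 = 2))
tryCatch(check_sir_data(bad_data), error = function(e) conditionMessage(e))
```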

We check convergence and see if we recover our parameter values for $\beta = \text{theta[1]}$ and $\gamma = \text{theta[2]}$:

```
> summarise_draws(theta, default_convergence_measures())
# A tibble: 2 × 4
  variable  rhat ess_bulk ess_tail
  <chr>    <dbl>    <dbl>    <dbl>
1 theta[1]  1.00     574.     779.
2 theta[2]  1.00     993.     932.
```

```
> summarise_draws(theta, ~ quantile(.x, probs = c(0.05, 0.5, 0.95)))
# A tibble: 2 × 4
  variable  `5%` `50%` `95%`
  <chr>    <dbl> <dbl> <dbl>
1 theta[1] 1.96  2.00  2.03
2 theta[2] 0.488 0.494 0.502
```

This looks good, but I also want to see whether we captured the infection curve. For that, we need to extract y_hat and plot it against the true curve from the solver.

```r
y_hat <- as_draws_rvars(sir_fit$draws(variables = "y_hat"))

# helper function used for plotting the results
make_df <- function(d, interval = .90, t, obs, pop) {
  S <- d$y_hat[, 1] # Susceptible draws, not used
  I <- d$y_hat[, 2] # Infected draws
  R <- d$y_hat[, 3] # Recovered draws, not used

  # compute the uncertainty interval
  low_quant <- (1 - interval) / 2
  high_quant <- interval + low_quant
  low <- apply(I, 2, quantile, probs = low_quant) * pop
  high <- apply(I, 2, quantile, probs = high_quant) * pop

  return(tibble(low, high, times = t, obs))
}

d <- make_df(d = y_hat, interval = 0.90, times, obs = sol$I, pop = n_pop)
ggplot(aes(times, obs), data = d) +
  geom_ribbon(aes(ymin = low, ymax = high), fill = "grey70") +
  geom_line(color = "red", linewidth = 0.3) +
  xlab("Days") + ylab("Infections (90% Uncertainty)") +
  ggtitle("SIR Estimation",
          subtitle = "Red curve is the true rate")
```

The true curve stays inside the 90% posterior interval for the expected number of infections.
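make_df() works on posterior rvars, but the core computation is simply an equal-tailed interval per time point, rescaled from fractions to counts. The base-R sketch below shows the same calculation on a plain draws matrix (rows are posterior draws, columns are time points). interval_band() is a hypothetical name, and the toy matrix is constructed only so the quantiles are easy to verify by hand.

```r
# Hypothetical helper: central interval per time point from a draws matrix
# (rows = posterior draws, columns = time points) of the infected fraction,
# rescaled to counts by multiplying by pop
interval_band <- function(draws, interval = 0.90, pop = 1) {
  low_q  <- (1 - interval) / 2  # e.g. 0.05 for a 90% interval
  high_q <- 1 - low_q           # e.g. 0.95
  data.frame(
    low  = apply(draws, 2, quantile, probs = low_q,  names = FALSE) * pop,
    high = apply(draws, 2, quantile, probs = high_q, names = FALSE) * pop
  )
}

# Toy input: 101 evenly spaced "draws" at 2 time points
draws <- cbind(seq(0, 1, length.out = 101), seq(0, 0.5, length.out = 101))
band <- interval_band(draws, interval = 0.90, pop = 100)
band
```

Note that this band describes uncertainty in the expected count $\lambda(t)$; an interval for new observations would also need Poisson noise, for example by applying rpois() to each draw before taking quantiles.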

Out of sample predictions

Recovering a true curve from a simulation is a good first step, but I also want to check whether this model can predict well out of sample. I am not going to define what “well” means here; instead, I will follow the same visual strategy as before. In particular, I will hold out the later part of the data, fit the model to the early part, predict forward over the held-out period, and see whether the predictions cover the true out-of-sample curve, perhaps with wider uncertainty. The key question is whether the model can predict the full curve before observing the infection peak, which in our case occurs slightly after day 6. This requires some work in Stan’s generated quantities block. The complete code listing is available at [3].
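The peak time mentioned above can be checked without deSolve. The sketch below integrates the same equations with a hand-written, fixed-step fourth-order Runge-Kutta (RK4) scheme, a simple stand-in for the adaptive ode45 solver used earlier, starting from the simulation's initial state on day 1 ($S = 762$, $I = 1$, $R = 0$, $\beta = 2$, $\gamma = 0.5$). The helper names sir_rhs() and sir_rk4() and the step size h = 0.01 are our own choices.

```r
# Right-hand side of the SIR equations on the count scale (same as SIR() above)
sir_rhs <- function(y, beta, gamma) {
  N <- sum(y)
  infection <- beta * y[2] / N * y[1]
  recovery  <- gamma * y[2]
  c(-infection, infection - recovery, recovery)
}

# Hypothetical fixed-step RK4 integrator returning the state at every step
sir_rk4 <- function(y0, t0, t_end, beta, gamma, h = 0.01) {
  n_steps <- round((t_end - t0) / h)
  out <- matrix(NA_real_, nrow = n_steps + 1, ncol = 4,
                dimnames = list(NULL, c("time", "S", "I", "R")))
  y <- y0
  out[1, ] <- c(t0, y)
  for (k in seq_len(n_steps)) {
    k1 <- sir_rhs(y, beta, gamma)
    k2 <- sir_rhs(y + h / 2 * k1, beta, gamma)
    k3 <- sir_rhs(y + h / 2 * k2, beta, gamma)
    k4 <- sir_rhs(y + h * k3, beta, gamma)
    y <- y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    out[k + 1, ] <- c(t0 + k * h, y)
  }
  out
}

# Same setup as the simulation: 763 students, 1 initial case on day 1
traj <- sir_rk4(y0 = c(762, 1, 0), t0 = 1, t_end = 15, beta = 2, gamma = 0.5)
peak <- traj[which.max(traj[, "I"]), ]
peak[c("time", "I")]
```

Because Runge-Kutta methods preserve linear invariants up to rounding, S + I + R stays at 763 along the whole trajectory, which is a useful check on any hand-rolled solver.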

```stan
...
data {
...
  int<lower=0> n_pred;        // number of cases to predict forward
  array[n_pred] real ts_pred; // future time points
}
...
generated quantities {
  ...

  vector[n_pred] y_pred;
  vector[n_pred] lambda_pred;

  // New initial conditions
  vector[3] y_init_pred = y_hat[n_obs, ];

  // New time zero is the last observed time
  real t0_pred = ts[n_obs];

  array[n_pred] vector[3] y_hat_pred = ode_rk45(SIR,
                                                y_init_pred,
                                                t0_pred,
                                                ts_pred,
                                                theta);
  for (i in 1:n_pred) {
    lambda_pred[i] = y_hat_pred[i, 2] * n_pop;
    y_pred[i] = poisson_rng(lambda_pred[i]);
  }
}
```
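Restarting the solver at ts[n_obs] from y_hat[n_obs] works because an ODE solution is determined by its state at any single time: continuing from the last observed state gives the same trajectory as solving straight through from t0. The sketch below checks this on the proportion scale used in the Stan program, with a hand-written fixed-step RK4 standing in for ode_rk45 (an assumption; Stan's solver is adaptive, so its numbers will differ slightly). The helper names, the split at day 5, and the parameter values are illustrative.

```r
# Right-hand side on the proportion scale, as in the Stan functions block
sir_prop <- function(y, beta, gamma) {
  c(-beta * y[1] * y[2],
     beta * y[1] * y[2] - gamma * y[2],
     gamma * y[2])
}

# Hypothetical fixed-step RK4 from t0 to t1; returns only the final state
rk4_to <- function(y, t0, t1, beta, gamma, h = 0.01) {
  n_steps <- round((t1 - t0) / h)
  for (k in seq_len(n_steps)) {
    k1 <- sir_prop(y, beta, gamma)
    k2 <- sir_prop(y + h / 2 * k1, beta, gamma)
    k3 <- sir_prop(y + h / 2 * k2, beta, gamma)
    k4 <- sir_prop(y + h * k3, beta, gamma)
    y <- y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
  }
  y
}

beta <- 2
gamma <- 0.5
y_init <- c(S = 762 / 763, I = 1 / 763, R = 0)

# Solve straight through from day 0 to day 10 ...
y_direct <- rk4_to(y_init, 0, 10, beta, gamma)

# ... or stop at day 5 (the "last observed time") and restart from that state
y_day5    <- rk4_to(y_init, 0, 5, beta, gamma)
y_restart <- rk4_to(y_day5, 5, 10, beta, gamma)

# Identical here, since both paths take exactly the same fixed steps
max(abs(y_direct - y_restart))
```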

What’s left now is to construct the new input dataset, compile the new model, sample, extract the ODE solutions, and plot the results.

```r
pct_train <- 0.30
n_train <- floor(n * pct_train)
n_pred <- n - n_train
times_pred <- times[(n_train + 1):n]
y_train <- y[1:n_train]
data <- list(n_obs = n_train, n_pred = n_pred,
             n_pop = n_pop, y = y_train,
             t0 = 0, ts = times[1:n_train], ts_pred = times_pred)

sir_pred_mod <- cmdstan_model("sir-pred.stan")
sir_pred_fit <- sir_pred_mod$sample(
  data = data,
  seed = 1234,
  chains = 4,
  parallel_chains = 4,
  iter_warmup = 500,
  iter_sampling = 500,
  adapt_delta = 0.98
)

y_hat <- as_draws_rvars(sir_pred_fit$draws(variables = "y_hat"))
y_hat_pred <- as_draws_rvars(sir_pred_fit$draws(variables = "y_hat_pred"))

d_train <- make_df(y_hat, 0.90, data$ts, sol$I[1:n_train], n_pop)
d_pred <- make_df(y_hat_pred, 0.90, data$ts_pred, sol$I[(n_train + 1):n], n_pop)
d <- dplyr::bind_rows(d_train, d_pred)

ggplot(aes(times, obs), data = d) +
  geom_ribbon(aes(ymin = low, ymax = high),
              fill = "grey70", alpha = 1/2) +
  geom_line(color = "red", linewidth = 0.3) +
  geom_vline(xintercept = d$times[n_train], linetype = "dotdash") +
  xlab("Days") + ylab("Infections (90% Uncertainty)") +
  ggtitle("SIR prediction trained on 5 days of data",
          subtitle = "Red curve is the true rate")
```
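The “5 days” in the plot title follows from the split arithmetic: with 100 equally spaced time points between day 1 and day 15 and 30% of them used for training, the last training observation falls just after day 5, ahead of the infection peak. A quick check in R, reproducing only the split from the code above:

```r
n <- 100
times <- seq(1, 15, length.out = n)

pct_train <- 0.30
n_train <- floor(n * pct_train)  # number of training time points
n_pred  <- n - n_train           # number of held-out time points

last_train_day <- times[n_train]      # becomes t0_pred in the prediction model
first_pred_day <- times[n_train + 1]  # first element of ts_pred

c(n_train = n_train, n_pred = n_pred,
  last_train_day = last_train_day, first_pred_day = first_pred_day)
```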

The results look believable, and at this point we could start testing the model on real data. It would be a mistake, however, to assume that this model is good enough to work well in the wild. The complete R code listing is available at [4].

Model improvements

Here is a partial list of model improvements and expansions that would likely be needed to model real outbreaks.

  • Real count data rarely follow a Poisson distribution, which forces the variance to equal the mean. A relatively simple change would be to replace it with a Negative Binomial likelihood and estimate the overdispersion from data.
  • Large outbreaks tend to have regional patterns, so it would make sense to fit a multilevel model with partial pooling across reporting locations.
  • The ODE dynamics alone are unlikely to be sufficient to describe the observed patterns, so baseline covariates should probably be included in the model. How do we learn which covariates to include? By consulting an epidemiologist familiar with the disease, of course.
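To see the difference the first point describes, compare the variance-to-mean ratio of Poisson draws with that of Negative Binomial draws at the same mean. In R's mean parameterisation, rnbinom(n, size = phi, mu = mu) has variance $\mu + \mu^2/\phi$, so a small $\phi$ means strong overdispersion; Stan's neg_binomial_2 distribution uses the same mean and dispersion parameterisation. The values of mu and phi below are illustrative.

```r
set.seed(2022)

mu  <- 100  # illustrative expected count
phi <- 10   # illustrative dispersion (smaller = more overdispersed)
n_draws <- 1e5

pois_draws <- rpois(n_draws, lambda = mu)
nb_draws   <- rnbinom(n_draws, size = phi, mu = mu)

# Theoretical variances: Poisson forces variance = mean,
# the Negative Binomial adds mu^2 / phi on top
var_theory <- c(poisson = mu, neg_binomial = mu + mu^2 / phi)

# Variance-to-mean ratios estimated from the draws
ratio_sim <- c(poisson      = var(pois_draws) / mean(pois_draws),
               neg_binomial = var(nb_draws) / mean(nb_draws))

var_theory
ratio_sim  # close to 1 for Poisson and to 1 + mu / phi = 11 for the Negative Binomial
```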

References

[1] Chatzilena, A., van Leeuwen, E., Ratmann, O., Baguelin, M., & Demiris, N. (2019). Contemporary statistical inference for infectious disease models using Stan. https://arxiv.org/abs/1903.00423v3

[2] public-materials/blog/SIR/sir.stan at master · generable/public-materials. GitHub. Published 2025. Accessed July 19, 2025. https://github.com/generable/public-materials/blob/master/blog/SIR/sir.stan

[3] public-materials/blog/SIR/sir-pred.stan at master · generable/public-materials. GitHub. Published 2025. Accessed July 19, 2025. https://github.com/generable/public-materials/blob/master/blog/SIR/sir-pred.stan

[4] public-materials/blog/SIR/sir.R at master · generable/public-materials. GitHub. Published 2025. Accessed July 19, 2025. https://github.com/generable/public-materials/blob/master/blog/SIR/sir.R
