Skip to content

Commit

Permalink
Day of the week tutorial (#344)
Browse files Browse the repository at this point in the history
* Adding example of weekday-effect back

* Adding the rst file

* Working on the tutorial

* wip [skip ci] expected to fail

* Reverting changegs

* Effect of size seven

* Passing the SampledValue instead of array to LatentHosp

* Adding argument to specify start day of the data in hospadmissions + test

* Reworking the tutorial (re-org)

* Adding day of the week to the index

* Removing old message

* Addressing comments by @damonbayer

* Addressing comments by @damonbayer

* inf_offset should always be within 0-6.

* Fixing bug in inf_offset

* Making pre-commit happy

* Combining personalized effect with deterministic

* Adding some details to the tutorial

* Fixing tutorial
  • Loading branch information
gvegayon authored Aug 21, 2024
1 parent f891cb7 commit 144d2b8
Show file tree
Hide file tree
Showing 8 changed files with 555 additions and 77 deletions.
353 changes: 353 additions & 0 deletions docs/source/tutorials/day_of_the_week.qmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,353 @@
---
title: Implementing a Day of the Week Effect
format: gfm
engine: jupyter
---

This document illustrates how to leverage the time-aware arrays to create a a day of the week effect. We use the same model designed in the hospital admissions-only tutorial.

## Recap: Hospital Admissions Model

In the ["Fitting a hospital admissions-only model" tutorial](https://cdcgov.github.io/multisignal-epi-inference/tutorials/hospital_admissions_model.html){target="_blank"}, we built a fairly complicated model that included pre-defined random variables as well as a custom random variable representing the reproductive number. In this tutorial, we will focus on adding a new component: the day-of-the-week effect. We start by reproducing the model without the day-of-the-week effect:

1. We load the data:

```{python}
# | label: setup
# | code-fold: true
# Setup
import numpyro
import polars as pl
from pyrenew import datasets
# Setting the number of devices
numpyro.set_host_device_count(2)
# Loading and processing the data
dat = (
datasets.load_wastewater()
.group_by("date")
.first()
.select(["date", "daily_hosp_admits"])
.sort("date")
.head(90)
)
daily_hosp_admits = dat["daily_hosp_admits"].to_numpy()
dates = dat["date"].to_numpy()
# Loading additional datasets
gen_int = datasets.load_generation_interval()
inf_hosp_int = datasets.load_infection_admission_interval()
# We only need the probability_mass column of each dataset
gen_int_array = gen_int["probability_mass"].to_numpy()
gen_int = gen_int_array
inf_hosp_int = inf_hosp_int["probability_mass"].to_numpy()
```

2. Next, we defined the model's components:

```{python}
# | label: latent-hosp
# | code-fold: true
from pyrenew import latent, deterministic, metaclass
import jax.numpy as jnp
import numpyro.distributions as dist
inf_hosp_int = deterministic.DeterministicPMF(
name="inf_hosp_int", value=inf_hosp_int
)
hosp_rate = metaclass.DistributionalRV(
name="IHR", distribution=dist.LogNormal(jnp.log(0.05), jnp.log(1.1))
)
latent_hosp = latent.HospitalAdmissions(
infection_to_admission_interval_rv=inf_hosp_int,
infection_hospitalization_ratio_rv=hosp_rate,
)
from pyrenew import model, process, observation, metaclass, transformation
from pyrenew.latent import (
InfectionInitializationProcess,
InitializeInfectionsExponentialGrowth,
)
# Infection process
latent_inf = latent.Infections()
I0 = InfectionInitializationProcess(
"I0_initialization",
metaclass.DistributionalRV(
name="I0",
distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)),
),
InitializeInfectionsExponentialGrowth(
gen_int_array.size,
deterministic.DeterministicVariable(name="rate", value=0.05),
),
t_unit=1,
)
# Generation interval and Rt
gen_int = deterministic.DeterministicPMF(name="gen_int", value=gen_int)
```

including the Rt effect:

```{python}
# | label: Rt-process
# | code-fold: true
class MyRt(metaclass.RandomVariable):
def __init__(self, sd_rv):
self.sd_rv = sd_rv
def validate(self):
pass
def sample(self, n_steps: int, **kwargs) -> tuple:
# Standard deviation of the random walk
sd_rt, *_ = self.sd_rv()
# Random walk step
step_rv = metaclass.DistributionalRV(
name="rw_step_rv", distribution=dist.Normal(0, sd_rt.value)
)
init_rv = metaclass.DistributionalRV(
name="init_log_rt", distribution=dist.Normal(0, 0.2)
)
# Random walk process
base_rv = process.SimpleRandomWalkProcess(
name="log_rt",
step_rv=step_rv,
init_rv=init_rv,
)
# Transforming the random walk to the Rt scale
rt_rv = metaclass.TransformedRandomVariable(
name="Rt_rv",
base_rv=base_rv,
transforms=transformation.ExpTransform(),
)
return rt_rv(n_steps=n_steps, **kwargs)
rtproc = MyRt(
metaclass.DistributionalRV(
name="Rt_random_walk_sd", distribution=dist.HalfNormal(0.025)
)
)
```

3. We defined the observation model:

```{python}
# | label: obs-model
# | code-fold: true
# we place a log-Normal prior on the concentration
# parameter of the negative binomial.
nb_conc_rv = metaclass.TransformedRandomVariable(
"concentration",
metaclass.DistributionalRV(
name="concentration_raw",
distribution=dist.TruncatedNormal(loc=0, scale=1, low=0.01),
),
transformation.PowerTransform(-2),
)
# now we define the observation process
obs = observation.NegativeBinomialObservation(
"negbinom_rv",
concentration_rv=nb_conc_rv,
)
```

4. And finally, we built the model:

```{python}
# | label: init-model
hosp_model = model.HospitalAdmissionsModel(
latent_infections_rv=latent_inf,
latent_hosp_admissions_rv=latent_hosp,
I0_rv=I0,
gen_int_rv=gen_int,
Rt_process_rv=rtproc,
hosp_admission_obs_process_rv=obs,
)
```

Here is what the model looks like without the day-of-the-week effect:

```{python}
# | label: fig-output-admissions-padding-and-weekday
# | fig-cap: Hospital Admissions posterior distribution without weekday effect
import jax
import numpy as np
# Model without weekday effect
hosp_model.run(
num_samples=2000,
num_warmup=2000,
data_observed_hosp_admissions=daily_hosp_admits,
rng_key=jax.random.key(54),
mcmc_args=dict(progress_bar=False),
)
# Plotting the posterior
out = hosp_model.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=np.pad(
daily_hosp_admits.astype(float),
(gen_int_array.size, 0),
constant_values=np.nan,
),
)
```


## Round 2: Incorporating day-of-the-week effects

We will re-use the infection to admission interval and infection to hospitalization rate from the previous model. But we will also add a day-of-the-week effect. To do this, we will add two additional arguments to the latent hospital admissions random variable: `day_of_the_week_rv` (a `RandomVariable`) and `obs_data_first_day_of_the_week` (an `int` mapping days of the week from 0:6, zero being Monday). The `day_of_the_week_rv`'s sample method should return a vector of length seven; those values are then broadcasted to match the length of the dataset. Moreover, since the observed data may start in a weekday other than Monday, the `obs_data_first_day_of_the_week` argument is used to offset the day-of-the-week effect.

For this example, the effect will be passed as a scaled Dirichlet distribution. It will consist of a `RandomVariable` that samples an array of length seven from numpyro's `distributions.Dirichlet`, and returns it multiplied by seven [^note-other-examples]:

[^note-other-examples]: A similar weekday effect is implemented in its own module, with example code [here](periodic_effects.html).

```{python}
# | label: weekly-effect
# Initializing the RV.
from pyrenew.metaclass import RandomVariable, SampledValue
class MyDOWEffect(RandomVariable):
"""A personalized day-of-the-week effect."""
def __init__(self):
"""Initialize the Dirichlet distribution."""
self.rv = dist.Dirichlet(concentration=jnp.ones(7))
def validate(self):
"""
Since we don't have any parameters, we don't need to validate anything.
"""
pass
def sample(self, **kwargs) -> tuple:
"""Sample the day-of-the-week effect."""
# Sample the simplex
res = (
numpyro.sample(
"dayofweek_effect_simplex",
self.rv,
)
* 7
)
# We make sure to record the value multiplied by 7.
# numpyro's sample function records the raw value before
# the transformation.
numpyro.deterministic("dayofweek_effect", res)
# All `RandomVariable` sample methods should return a tuple
# of `SampledValue` instances.
return (SampledValue(res),)
# Instantiating the day-of-the-week effect
dayofweek_effect = MyDOWEffect()
```

Now, by re-defining the latent hospital admissions random variable with the day-of-the-week effect, we can build a model that includes this effect. Since the day-of-the-week effect takes into account the first day of the dataset, we need to determine the day of the week of the first observation. We can do this by converting the first date in the dataset to a `datetime` object and extracting the day of the week:

```{python}
# | label: latent-hosp-weekday
# Figuring out the day of the week of the first observation
import datetime as dt
first_dow_in_data = dates[0].astype(dt.datetime).weekday()
first_dow_in_data # zero
# Re-defining the latent hospital admissions RV, now with the
# day-of-the-week effect
latent_hosp_wday_effect = latent.HospitalAdmissions(
infection_to_admission_interval_rv=inf_hosp_int,
infection_hospitalization_ratio_rv=hosp_rate,
day_of_week_effect_rv=dayofweek_effect,
# Concidirently, this is zero
obs_data_first_day_of_the_week=first_dow_in_data,
)
# New model with day-of-the-week effect
hosp_model_dow = model.HospitalAdmissionsModel(
latent_infections_rv=latent_inf,
latent_hosp_admissions_rv=latent_hosp_wday_effect,
I0_rv=I0,
gen_int_rv=gen_int,
Rt_process_rv=rtproc,
hosp_admission_obs_process_rv=obs,
)
```

Running the model:

```{python}
# | label: model-2-run-weekday
# Model with weekday effect
hosp_model_dow.run(
num_samples=2000,
num_warmup=2000,
data_observed_hosp_admissions=daily_hosp_admits,
rng_key=jax.random.key(54),
mcmc_args=dict(progress_bar=False),
)
```

As a result, we can see the posterior distribution of our novel day-of-the-week effect:

```{python}
# | label: fig-output-day-of-week
# | fig-cap: Day of the week effect
out = hosp_model_dow.plot_posterior(
var="dayofweek_effect", ylab="Day of the Week Effect", samples=500
)
sp = hosp_model_dow.spread_draws(["dayofweek_effect"])
```

The new model with the day-of-the-week effect can be compared to the previous model without the effect. Finally, let's reproduce the figure without the day-of-the-week effect, and then plot the new model with the effect:

```{python}
# | label: fig-output-admissions-original
# | fig-cap: Hospital Admissions posterior distribution without weekday effect
# Figure without weekday effect
out = hosp_model.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=np.pad(
daily_hosp_admits.astype(float),
(gen_int_array.size, 0),
constant_values=np.nan,
),
)
```

```{python}
# | label: fig-output-admissions-wof
# | fig-cap: Hospital Admissions posterior distribution with weekday effect
# Figure with weekday effect
out_dow = hosp_model_dow.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=np.pad(
daily_hosp_admits.astype(float),
(gen_int_array.size, 0),
constant_values=np.nan,
),
)
```
5 changes: 5 additions & 0 deletions docs/source/tutorials/day_of_the_week.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
.. WARNING
.. Please do not edit this file directly.
.. This file is just a placeholder.
.. For the source file, see:
.. <https://github.com/CDCgov/multisignal-epi-inference/tree/main/docs/source/tutorials/day_of_the_week.qmd>
2 changes: 1 addition & 1 deletion docs/source/tutorials/hospital_admissions_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ hosp_rate = metaclass.DistributionalRV(
latent_hosp = latent.HospitalAdmissions(
infection_to_admission_interval_rv=inf_hosp_int,
infect_hosp_rate_rv=hosp_rate,
infection_hospitalization_ratio_rv=hosp_rate,
)
```

Expand Down
1 change: 1 addition & 0 deletions docs/source/tutorials/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ This section contains tutorials that demonstrate how to use the `pyrenew` packag
extending_pyrenew
periodic_effects
time
day_of_the_week
Loading

0 comments on commit 144d2b8

Please sign in to comment.