Skip to content

Commit

Permalink
Using TransformedRandomVariable (#406)
Browse files Browse the repository at this point in the history
  • Loading branch information
gvegayon authored Aug 21, 2024
1 parent 144d2b8 commit 7ad81f1
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 42 deletions.
52 changes: 11 additions & 41 deletions docs/source/tutorials/day_of_the_week.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -214,53 +214,23 @@ out = hosp_model.plot_posterior(

We will re-use the infection to admission interval and infection to hospitalization rate from the previous model. But we will also add a day-of-the-week effect. To do this, we will add two additional arguments to the latent hospital admissions random variable: `day_of_the_week_rv` (a `RandomVariable`) and `obs_data_first_day_of_the_week` (an `int` mapping days of the week from 0:6, zero being Monday). The `day_of_the_week_rv`'s sample method should return a vector of length seven; those values are then broadcasted to match the length of the dataset. Moreover, since the observed data may start in a weekday other than Monday, the `obs_data_first_day_of_the_week` argument is used to offset the day-of-the-week effect.

For this example, the effect will be passed as a scaled Dirichlet distribution. It will consist of a `RandomVariable` that samples an array of length seven from numpyro's `distributions.Dirichlet`, and returns it multiplied by seven [^note-other-examples]:
For this example, the effect will be passed as a scaled Dirichlet distribution. It will consist of a `TransformedRandomVariable` that samples an array of length seven from numpyro's `distributions.Dirichlet` and applies a `transformation.AffineTransform` to scale it by seven. [^note-other-examples]:

[^note-other-examples]: A similar weekday effect is implemented in its own module, with example code [here](periodic_effects.html).

```{python}
# | label: weekly-effect
# Initializing the RV.
from pyrenew.metaclass import RandomVariable, SampledValue
class MyDOWEffect(RandomVariable):
"""A personalized day-of-the-week effect."""
def __init__(self):
"""Initialize the Dirichlet distribution."""
self.rv = dist.Dirichlet(concentration=jnp.ones(7))
def validate(self):
"""
Since we don't have any parameters, we don't need to validate anything.
"""
pass
def sample(self, **kwargs) -> tuple:
"""Sample the day-of-the-week effect."""
# Sample the simplex
res = (
numpyro.sample(
"dayofweek_effect_simplex",
self.rv,
)
* 7
)
# We make sure to record the value multiplied by 7.
# numpyro's sample function records the raw value before
# the transformation.
numpyro.deterministic("dayofweek_effect", res)
# All `RandomVariable` sample methods should return a tuple
# of `SampledValue` instances.
return (SampledValue(res),)
# Instantiating the day-of-the-week effect
dayofweek_effect = MyDOWEffect()
dayofweek_effect = metaclass.TransformedRandomVariable(
name="dayofweek_effect",
base_rv=metaclass.DistributionalRV(
name="dayofweek_effect_raw",
distribution=dist.Dirichlet(jnp.ones(7)),
),
transforms=transformation.AffineTransform(
loc=0, scale=7, domain=jnp.array([0, 1])
),
)
```

Now, by re-defining the latent hospital admissions random variable with the day-of-the-week effect, we can build a model that includes this effect. Since the day-of-the-week effect takes into account the first day of the dataset, we need to determine the day of the week of the first observation. We can do this by converting the first date in the dataset to a `datetime` object and extracting the day of the week:
Expand Down
4 changes: 3 additions & 1 deletion src/pyrenew/latent/hospitaladmissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,9 @@ def sample(
# 1. Get the day of the week effect
# 2. Identify the offset of the latent_infections
# 3. Apply the day of the week effect to the latent_hospital_admissions
dow_effect_sampled = self.day_of_week_effect_rv(**kwargs)[0]
dow_effect_sampled = self.day_of_week_effect_rv(**kwargs, record=True)[
0
]

if dow_effect_sampled.value.size != 7:
raise ValueError(
Expand Down

0 comments on commit 7ad81f1

Please sign in to comment.