From 7ad81f16141fc4fed6cd07ee994286da20649c5a Mon Sep 17 00:00:00 2001 From: "George G. Vega Yon" Date: Wed, 21 Aug 2024 15:17:20 -0600 Subject: [PATCH] Using TransformedRandomVariable (#406) --- docs/source/tutorials/day_of_the_week.qmd | 52 +++++------------------ src/pyrenew/latent/hospitaladmissions.py | 4 +- 2 files changed, 14 insertions(+), 42 deletions(-) diff --git a/docs/source/tutorials/day_of_the_week.qmd b/docs/source/tutorials/day_of_the_week.qmd index ceadc985..4bdb110b 100644 --- a/docs/source/tutorials/day_of_the_week.qmd +++ b/docs/source/tutorials/day_of_the_week.qmd @@ -214,53 +214,23 @@ out = hosp_model.plot_posterior( We will re-use the infection to admission interval and infection to hospitalization rate from the previous model. But we will also add a day-of-the-week effect. To do this, we will add two additional arguments to the latent hospital admissions random variable: `day_of_the_week_rv` (a `RandomVariable`) and `obs_data_first_day_of_the_week` (an `int` mapping days of the week from 0:6, zero being Monday). The `day_of_the_week_rv`'s sample method should return a vector of length seven; those values are then broadcasted to match the length of the dataset. Moreover, since the observed data may start in a weekday other than Monday, the `obs_data_first_day_of_the_week` argument is used to offset the day-of-the-week effect. -For this example, the effect will be passed as a scaled Dirichlet distribution. It will consist of a `RandomVariable` that samples an array of length seven from numpyro's `distributions.Dirichlet`, and returns it multiplied by seven [^note-other-examples]: +For this example, the effect will be passed as a scaled Dirichlet distribution. It will consist of a `TransformedRandomVariable` that samples an array of length seven from numpyro's `distributions.Dirichlet` and applies a `transformation.AffineTransform` to scale it by seven. [^note-other-examples]: [^note-other-examples]: A similar weekday effect is implemented in its own module, with example code [here](periodic_effects.html). ```{python} # | label: weekly-effect -# Initializing the RV. -from pyrenew.metaclass import RandomVariable, SampledValue - - -class MyDOWEffect(RandomVariable): - """A personalized day-of-the-week effect.""" - - def __init__(self): - """Initialize the Dirichlet distribution.""" - self.rv = dist.Dirichlet(concentration=jnp.ones(7)) - - def validate(self): - """ - Since we don't have any parameters, we don't need to validate anything. - """ - pass - - def sample(self, **kwargs) -> tuple: - """Sample the day-of-the-week effect.""" - - # Sample the simplex - res = ( - numpyro.sample( - "dayofweek_effect_simplex", - self.rv, - ) - * 7 - ) - - # We make sure to record the value multiplied by 7. - # numpyro's sample function records the raw value before - # the transformation. - numpyro.deterministic("dayofweek_effect", res) - - # All `RandomVariable` sample methods should return a tuple - # of `SampledValue` instances. - return (SampledValue(res),) - - # Instantiating the day-of-the-week effect -dayofweek_effect = MyDOWEffect() +dayofweek_effect = metaclass.TransformedRandomVariable( + name="dayofweek_effect", + base_rv=metaclass.DistributionalRV( + name="dayofweek_effect_raw", + distribution=dist.Dirichlet(jnp.ones(7)), + ), + transforms=transformation.AffineTransform( + loc=0, scale=7, domain=jnp.array([0, 1]) + ), +) ``` Now, by re-defining the latent hospital admissions random variable with the day-of-the-week effect, we can build a model that includes this effect. Since the day-of-the-week effect takes into account the first day of the dataset, we need to determine the day of the week of the first observation. We can do this by converting the first date in the dataset to a `datetime` object and extracting the day of the week: diff --git a/src/pyrenew/latent/hospitaladmissions.py b/src/pyrenew/latent/hospitaladmissions.py index 3d1a1126..f283e861 100644 --- a/src/pyrenew/latent/hospitaladmissions.py +++ b/src/pyrenew/latent/hospitaladmissions.py @@ -220,7 +220,9 @@ def sample( # 1. Get the day of the week effect # 2. Identify the offset of the latent_infections # 3. Apply the day of the week effect to the latent_hospital_admissions - dow_effect_sampled = self.day_of_week_effect_rv(**kwargs)[0] + dow_effect_sampled = self.day_of_week_effect_rv(**kwargs, record=True)[ + 0 + ] if dow_effect_sampled.value.size != 7: raise ValueError(