From 144d2b811222d7a7d8126b2c0d93621342b5e85f Mon Sep 17 00:00:00 2001 From: "George G. Vega Yon" Date: Wed, 21 Aug 2024 09:35:43 -0600 Subject: [PATCH] Day of the week tutorial (#344) * Adding example of weekday-effect back * Adding the rst file * Working on the tutorial * wip [skip ci] expected to fail * Reverting changegs * Effect of size seven * Passing the SampledValue instead of array to LatentHosp * Adding argument to specify start day of the data in hospadmissions + test * Reworking the tutorial (re-org) * Adding day of the week to the index * Removing old message * Addressing comments by @damonbayer * Addressing comments by @damonbayer * inf_offset should always be within 0-6. * Fixing bug in inf_offset * Making pre-commit happy * Combining personalized effect with deterministic * Adding some details to the tutorial * Fixing tutorial --- docs/source/tutorials/day_of_the_week.qmd | 353 ++++++++++++++++++ docs/source/tutorials/day_of_the_week.rst | 5 + .../tutorials/hospital_admissions_model.qmd | 2 +- docs/source/tutorials/index.rst | 1 + src/pyrenew/latent/hospitaladmissions.py | 142 ++++--- src/pyrenew/model/admissionsmodel.py | 2 +- src/test/test_latent_admissions.py | 85 ++++- src/test/test_model_hosp_admissions.py | 42 +-- 8 files changed, 555 insertions(+), 77 deletions(-) create mode 100644 docs/source/tutorials/day_of_the_week.qmd create mode 100644 docs/source/tutorials/day_of_the_week.rst diff --git a/docs/source/tutorials/day_of_the_week.qmd b/docs/source/tutorials/day_of_the_week.qmd new file mode 100644 index 00000000..ceadc985 --- /dev/null +++ b/docs/source/tutorials/day_of_the_week.qmd @@ -0,0 +1,353 @@ +--- +title: Implementing a Day of the Week Effect +format: gfm +engine: jupyter +--- + +This document illustrates how to leverage the time-aware arrays to create a a day of the week effect. We use the same model designed in the hospital admissions-only tutorial. + +## Recap: Hospital Admissions Model + +In the ["Fitting a hospital admissions-only model" tutorial](https://cdcgov.github.io/multisignal-epi-inference/tutorials/hospital_admissions_model.html){target="_blank"}, we built a fairly complicated model that included pre-defined random variables as well as a custom random variable representing the reproductive number. In this tutorial, we will focus on adding a new component: the day-of-the-week effect. We start by reproducing the model without the day-of-the-week effect: + +1. We load the data: + +```{python} +# | label: setup +# | code-fold: true +# Setup +import numpyro +import polars as pl +from pyrenew import datasets + +# Setting the number of devices +numpyro.set_host_device_count(2) + +# Loading and processing the data +dat = ( + datasets.load_wastewater() + .group_by("date") + .first() + .select(["date", "daily_hosp_admits"]) + .sort("date") + .head(90) +) + +daily_hosp_admits = dat["daily_hosp_admits"].to_numpy() +dates = dat["date"].to_numpy() + +# Loading additional datasets +gen_int = datasets.load_generation_interval() +inf_hosp_int = datasets.load_infection_admission_interval() + +# We only need the probability_mass column of each dataset +gen_int_array = gen_int["probability_mass"].to_numpy() +gen_int = gen_int_array +inf_hosp_int = inf_hosp_int["probability_mass"].to_numpy() +``` + +2. Next, we defined the model's components: + +```{python} +# | label: latent-hosp +# | code-fold: true +from pyrenew import latent, deterministic, metaclass +import jax.numpy as jnp +import numpyro.distributions as dist + +inf_hosp_int = deterministic.DeterministicPMF( + name="inf_hosp_int", value=inf_hosp_int +) + +hosp_rate = metaclass.DistributionalRV( + name="IHR", distribution=dist.LogNormal(jnp.log(0.05), jnp.log(1.1)) +) + +latent_hosp = latent.HospitalAdmissions( + infection_to_admission_interval_rv=inf_hosp_int, + infection_hospitalization_ratio_rv=hosp_rate, +) + +from pyrenew import model, process, observation, metaclass, transformation +from pyrenew.latent import ( + InfectionInitializationProcess, + InitializeInfectionsExponentialGrowth, +) + + +# Infection process +latent_inf = latent.Infections() +I0 = InfectionInitializationProcess( + "I0_initialization", + metaclass.DistributionalRV( + name="I0", + distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)), + ), + InitializeInfectionsExponentialGrowth( + gen_int_array.size, + deterministic.DeterministicVariable(name="rate", value=0.05), + ), + t_unit=1, +) + +# Generation interval and Rt +gen_int = deterministic.DeterministicPMF(name="gen_int", value=gen_int) +``` + +including the Rt effect: + +```{python} +# | label: Rt-process +# | code-fold: true +class MyRt(metaclass.RandomVariable): + def __init__(self, sd_rv): + self.sd_rv = sd_rv + + def validate(self): + pass + + def sample(self, n_steps: int, **kwargs) -> tuple: + # Standard deviation of the random walk + sd_rt, *_ = self.sd_rv() + + # Random walk step + step_rv = metaclass.DistributionalRV( + name="rw_step_rv", distribution=dist.Normal(0, sd_rt.value) + ) + + init_rv = metaclass.DistributionalRV( + name="init_log_rt", distribution=dist.Normal(0, 0.2) + ) + + # Random walk process + base_rv = process.SimpleRandomWalkProcess( + name="log_rt", + step_rv=step_rv, + init_rv=init_rv, + ) + + # Transforming the random walk to the Rt scale + rt_rv = metaclass.TransformedRandomVariable( + name="Rt_rv", + base_rv=base_rv, + transforms=transformation.ExpTransform(), + ) + + return rt_rv(n_steps=n_steps, **kwargs) + + +rtproc = MyRt( + metaclass.DistributionalRV( + name="Rt_random_walk_sd", distribution=dist.HalfNormal(0.025) + ) +) +``` + +3. We defined the observation model: + +```{python} +# | label: obs-model +# | code-fold: true +# we place a log-Normal prior on the concentration +# parameter of the negative binomial. +nb_conc_rv = metaclass.TransformedRandomVariable( + "concentration", + metaclass.DistributionalRV( + name="concentration_raw", + distribution=dist.TruncatedNormal(loc=0, scale=1, low=0.01), + ), + transformation.PowerTransform(-2), +) + +# now we define the observation process +obs = observation.NegativeBinomialObservation( + "negbinom_rv", + concentration_rv=nb_conc_rv, +) +``` + +4. And finally, we built the model: + +```{python} +# | label: init-model +hosp_model = model.HospitalAdmissionsModel( + latent_infections_rv=latent_inf, + latent_hosp_admissions_rv=latent_hosp, + I0_rv=I0, + gen_int_rv=gen_int, + Rt_process_rv=rtproc, + hosp_admission_obs_process_rv=obs, +) +``` + +Here is what the model looks like without the day-of-the-week effect: + +```{python} +# | label: fig-output-admissions-padding-and-weekday +# | fig-cap: Hospital Admissions posterior distribution without weekday effect +import jax +import numpy as np + +# Model without weekday effect +hosp_model.run( + num_samples=2000, + num_warmup=2000, + data_observed_hosp_admissions=daily_hosp_admits, + rng_key=jax.random.key(54), + mcmc_args=dict(progress_bar=False), +) + +# Plotting the posterior +out = hosp_model.plot_posterior( + var="latent_hospital_admissions", + ylab="Hospital Admissions", + obs_signal=np.pad( + daily_hosp_admits.astype(float), + (gen_int_array.size, 0), + constant_values=np.nan, + ), +) +``` + + +## Round 2: Incorporating day-of-the-week effects + +We will re-use the infection to admission interval and infection to hospitalization rate from the previous model. But we will also add a day-of-the-week effect. To do this, we will add two additional arguments to the latent hospital admissions random variable: `day_of_the_week_rv` (a `RandomVariable`) and `obs_data_first_day_of_the_week` (an `int` mapping days of the week from 0:6, zero being Monday). The `day_of_the_week_rv`'s sample method should return a vector of length seven; those values are then broadcasted to match the length of the dataset. Moreover, since the observed data may start in a weekday other than Monday, the `obs_data_first_day_of_the_week` argument is used to offset the day-of-the-week effect. + +For this example, the effect will be passed as a scaled Dirichlet distribution. It will consist of a `RandomVariable` that samples an array of length seven from numpyro's `distributions.Dirichlet`, and returns it multiplied by seven [^note-other-examples]: + +[^note-other-examples]: A similar weekday effect is implemented in its own module, with example code [here](periodic_effects.html). + +```{python} +# | label: weekly-effect +# Initializing the RV. +from pyrenew.metaclass import RandomVariable, SampledValue + + +class MyDOWEffect(RandomVariable): + """A personalized day-of-the-week effect.""" + + def __init__(self): + """Initialize the Dirichlet distribution.""" + self.rv = dist.Dirichlet(concentration=jnp.ones(7)) + + def validate(self): + """ + Since we don't have any parameters, we don't need to validate anything. + """ + pass + + def sample(self, **kwargs) -> tuple: + """Sample the day-of-the-week effect.""" + + # Sample the simplex + res = ( + numpyro.sample( + "dayofweek_effect_simplex", + self.rv, + ) + * 7 + ) + + # We make sure to record the value multiplied by 7. + # numpyro's sample function records the raw value before + # the transformation. + numpyro.deterministic("dayofweek_effect", res) + + # All `RandomVariable` sample methods should return a tuple + # of `SampledValue` instances. + return (SampledValue(res),) + + +# Instantiating the day-of-the-week effect +dayofweek_effect = MyDOWEffect() +``` + +Now, by re-defining the latent hospital admissions random variable with the day-of-the-week effect, we can build a model that includes this effect. Since the day-of-the-week effect takes into account the first day of the dataset, we need to determine the day of the week of the first observation. We can do this by converting the first date in the dataset to a `datetime` object and extracting the day of the week: + +```{python} +# | label: latent-hosp-weekday +# Figuring out the day of the week of the first observation +import datetime as dt + +first_dow_in_data = dates[0].astype(dt.datetime).weekday() +first_dow_in_data # zero + +# Re-defining the latent hospital admissions RV, now with the +# day-of-the-week effect +latent_hosp_wday_effect = latent.HospitalAdmissions( + infection_to_admission_interval_rv=inf_hosp_int, + infection_hospitalization_ratio_rv=hosp_rate, + day_of_week_effect_rv=dayofweek_effect, + # Concidirently, this is zero + obs_data_first_day_of_the_week=first_dow_in_data, +) + +# New model with day-of-the-week effect +hosp_model_dow = model.HospitalAdmissionsModel( + latent_infections_rv=latent_inf, + latent_hosp_admissions_rv=latent_hosp_wday_effect, + I0_rv=I0, + gen_int_rv=gen_int, + Rt_process_rv=rtproc, + hosp_admission_obs_process_rv=obs, +) +``` + +Running the model: + +```{python} +# | label: model-2-run-weekday +# Model with weekday effect +hosp_model_dow.run( + num_samples=2000, + num_warmup=2000, + data_observed_hosp_admissions=daily_hosp_admits, + rng_key=jax.random.key(54), + mcmc_args=dict(progress_bar=False), +) +``` + +As a result, we can see the posterior distribution of our novel day-of-the-week effect: + +```{python} +# | label: fig-output-day-of-week +# | fig-cap: Day of the week effect +out = hosp_model_dow.plot_posterior( + var="dayofweek_effect", ylab="Day of the Week Effect", samples=500 +) + +sp = hosp_model_dow.spread_draws(["dayofweek_effect"]) +``` + +The new model with the day-of-the-week effect can be compared to the previous model without the effect. Finally, let's reproduce the figure without the day-of-the-week effect, and then plot the new model with the effect: + +```{python} +# | label: fig-output-admissions-original +# | fig-cap: Hospital Admissions posterior distribution without weekday effect +# Figure without weekday effect +out = hosp_model.plot_posterior( + var="latent_hospital_admissions", + ylab="Hospital Admissions", + obs_signal=np.pad( + daily_hosp_admits.astype(float), + (gen_int_array.size, 0), + constant_values=np.nan, + ), +) +``` + +```{python} +# | label: fig-output-admissions-wof +# | fig-cap: Hospital Admissions posterior distribution with weekday effect +# Figure with weekday effect +out_dow = hosp_model_dow.plot_posterior( + var="latent_hospital_admissions", + ylab="Hospital Admissions", + obs_signal=np.pad( + daily_hosp_admits.astype(float), + (gen_int_array.size, 0), + constant_values=np.nan, + ), +) +``` diff --git a/docs/source/tutorials/day_of_the_week.rst b/docs/source/tutorials/day_of_the_week.rst new file mode 100644 index 00000000..ef1ea658 --- /dev/null +++ b/docs/source/tutorials/day_of_the_week.rst @@ -0,0 +1,5 @@ +.. WARNING +.. Please do not edit this file directly. +.. This file is just a placeholder. +.. For the source file, see: +.. diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index cbd73e39..0889782f 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -151,7 +151,7 @@ hosp_rate = metaclass.DistributionalRV( latent_hosp = latent.HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp_int, - infect_hosp_rate_rv=hosp_rate, + infection_hospitalization_ratio_rv=hosp_rate, ) ``` diff --git a/docs/source/tutorials/index.rst b/docs/source/tutorials/index.rst index d7b6c0ac..11b0f958 100644 --- a/docs/source/tutorials/index.rst +++ b/docs/source/tutorials/index.rst @@ -12,3 +12,4 @@ This section contains tutorials that demonstrate how to use the `pyrenew` packag extending_pyrenew periodic_effects time + day_of_the_week diff --git a/src/pyrenew/latent/hospitaladmissions.py b/src/pyrenew/latent/hospitaladmissions.py index 0a8ea564..3d1a1126 100644 --- a/src/pyrenew/latent/hospitaladmissions.py +++ b/src/pyrenew/latent/hospitaladmissions.py @@ -7,8 +7,8 @@ import jax.numpy as jnp import numpyro -from jax.typing import ArrayLike +import pyrenew.arrayutils as au from pyrenew.deterministic import DeterministicVariable from pyrenew.metaclass import RandomVariable, SampledValue @@ -23,13 +23,18 @@ class HospitalAdmissionsSample(NamedTuple): The infection-to-hospitalization rate. Defaults to None. latent_hospital_admissions : SampledValue or None The computed number of hospital admissions. Defaults to None. + multiplier : SampledValue or None + The day of the week effect multiplier. Defaults to None. It + should match the number of timepoints in the latent hospital + admissions. """ infection_hosp_rate: SampledValue | None = None latent_hospital_admissions: SampledValue | None = None + multiplier: SampledValue | None = None def __repr__(self): - return f"HospitalAdmissionsSample(infection_hosp_rate={self.infection_hosp_rate}, latent_hospital_admissions={self.latent_hospital_admissions})" + return f"HospitalAdmissionsSample(infection_hosp_rate={self.infection_hosp_rate}, latent_hospital_admissions={self.latent_hospital_admissions}, multiplier={self.multiplier})" class HospitalAdmissions(RandomVariable): @@ -65,9 +70,10 @@ class HospitalAdmissions(RandomVariable): def __init__( self, infection_to_admission_interval_rv: RandomVariable, - infect_hosp_rate_rv: RandomVariable, + infection_hospitalization_ratio_rv: RandomVariable, day_of_week_effect_rv: RandomVariable | None = None, - hosp_report_prob_rv: RandomVariable | None = None, + hospitalization_reporting_ratio_rv: RandomVariable | None = None, + obs_data_first_day_of_the_week: int = 0, ) -> None: """ Default constructor @@ -77,13 +83,19 @@ def __init__( infection_to_admission_interval_rv : RandomVariable pmf for reporting (informing) hospital admissions (see pyrenew.observations.Deterministic). - infect_hosp_rate_rv : RandomVariable + infection_hospitalization_ratio_rv : RandomVariable Infection to hospitalization rate random variable. day_of_week_effect_rv : RandomVariable, optional - Day of the week effect. - hosp_report_prob_rv : RandomVariable, optional + Day of the week effect. Should return a SampledValue with 7 + values. Defaults to a deterministic variable with + jax.numpy.ones(7) (no effect). + hospitalization_reporting_ratio_rv : RandomVariable, optional Random variable for the hospital admission reporting probability. Defaults to 1 (full reporting). + obs_data_first_day_of_the_week : int, optional + The day of the week that the first day of the observation data + corresponds to. Valid values are 0-6, where 0 is Monday and 6 is + Sunday. Defaults to 0. Returns ------- @@ -92,33 +104,40 @@ def __init__( if day_of_week_effect_rv is None: day_of_week_effect_rv = DeterministicVariable( - name="weekday_effect", value=1 + name="weekday_effect", value=jnp.ones(7) ) - if hosp_report_prob_rv is None: - hosp_report_prob_rv = DeterministicVariable( - name="hosp_report_prob", value=1 + if hospitalization_reporting_ratio_rv is None: + hospitalization_reporting_ratio_rv = DeterministicVariable( + name="hosp_report_prob", value=1.0 ) HospitalAdmissions.validate( - infect_hosp_rate_rv, - day_of_week_effect_rv, - hosp_report_prob_rv, infection_to_admission_interval_rv, + infection_hospitalization_ratio_rv, + day_of_week_effect_rv, + hospitalization_reporting_ratio_rv, + obs_data_first_day_of_the_week, ) - self.infect_hosp_rate_rv = infect_hosp_rate_rv - self.day_of_week_effect_rv = day_of_week_effect_rv - self.hosp_report_prob_rv = hosp_report_prob_rv self.infection_to_admission_interval_rv = ( infection_to_admission_interval_rv ) + self.infection_hospitalization_ratio_rv = ( + infection_hospitalization_ratio_rv + ) + self.day_of_week_effect_rv = day_of_week_effect_rv + self.hospitalization_reporting_ratio_rv = ( + hospitalization_reporting_ratio_rv + ) + self.obs_data_first_day_of_the_week = obs_data_first_day_of_the_week @staticmethod def validate( - infect_hosp_rate_rv: Any, - day_of_week_effect_rv: Any, - hosp_report_prob_rv: Any, infection_to_admission_interval_rv: Any, + infection_hospitalization_ratio_rv: Any, + day_of_week_effect_rv: Any, + hospitalization_reporting_ratio_rv: Any, + obs_data_first_day_of_the_week: Any, ) -> None: """ Validates that the IHR, weekday effects, probability of being @@ -127,16 +146,20 @@ def validate( Parameters ---------- - infect_hosp_rate_rv : Any + infection_to_admission_interval_rv : Any + Possibly incorrect input for the infection to hospitalization + interval distribution. + infection_hospitalization_ratio_rv : Any Possibly incorrect input for infection to hospitalization rate distribution. day_of_week_effect_rv : Any Possibly incorrect input for day of the week effect. - hosp_report_prob_rv : Any + hospitalization_reporting_ratio_rv : Any Possibly incorrect input for distribution or fixed value for the hospital admission reporting probability. - infection_to_admission_interval_rv : Any - Possibly incorrect input for hospital admissions - reporting delay interval pmf. + obs_data_first_day_of_the_week : Any + Possibly incorrect input for the day of the week that the first day + of the observation data corresponds to. Valid values are 0-6, where + 0 is Monday and 6 is Sunday. Returns ------- @@ -145,19 +168,21 @@ def validate( Raises ------ AssertionError - If the object `distr` is not an instance of `dist.Distribution`, indicating - that the validation has failed. + If any of the random variables are not of the correct type, or if + the day of the week is not within the valid range. """ - assert isinstance(infect_hosp_rate_rv, RandomVariable) - assert isinstance(day_of_week_effect_rv, RandomVariable) - assert isinstance(hosp_report_prob_rv, RandomVariable) assert isinstance(infection_to_admission_interval_rv, RandomVariable) + assert isinstance(infection_hospitalization_ratio_rv, RandomVariable) + assert isinstance(day_of_week_effect_rv, RandomVariable) + assert isinstance(hospitalization_reporting_ratio_rv, RandomVariable) + assert isinstance(obs_data_first_day_of_the_week, int) + assert 0 <= obs_data_first_day_of_the_week <= 6 return None def sample( self, - latent_infections: ArrayLike, + latent_infections: SampledValue, **kwargs, ) -> HospitalAdmissionsSample: """ @@ -165,8 +190,8 @@ def sample( Parameters ---------- - latent_infections : ArrayLike - Latent infections. + latent_infections : SampledValue + Latent infections. Possibly the output of the `latent.Infections()`. **kwargs : dict, optional Additional keyword arguments passed through to internal `sample()` calls, should there be any. @@ -176,10 +201,8 @@ def sample( HospitalAdmissionsSample """ - infection_hosp_rate, *_ = self.infect_hosp_rate_rv(**kwargs) - - latent_hospital_admissions_raw = ( - infection_hosp_rate.value * latent_infections + infection_hosp_rate, *_ = self.infection_hospitalization_ratio_rv( + **kwargs ) ( @@ -188,23 +211,45 @@ def sample( ) = self.infection_to_admission_interval_rv(**kwargs) latent_hospital_admissions = jnp.convolve( - latent_hospital_admissions_raw, + infection_hosp_rate.value * latent_infections.value, infection_to_admission_interval.value, mode="full", - )[: latent_hospital_admissions_raw.shape[0]] + )[: latent_infections.value.shape[0]] + + # Applying the day of the week effect. For this we need to: + # 1. Get the day of the week effect + # 2. Identify the offset of the latent_infections + # 3. Apply the day of the week effect to the latent_hospital_admissions + dow_effect_sampled = self.day_of_week_effect_rv(**kwargs)[0] + + if dow_effect_sampled.value.size != 7: + raise ValueError( + "Day of the week effect should have 7 values. " + f"Got {dow_effect_sampled.value.size} instead." + ) - # Applying the day of the week effect - latent_hospital_admissions = ( - latent_hospital_admissions - * self.day_of_week_effect_rv( - n_timepoints=latent_hospital_admissions.size, **kwargs - )[0].value + # Identifying the offset + if latent_infections.t_start is None: + inf_offset = 0 + else: + inf_offset = latent_infections.t_start + + inf_offset = (inf_offset + self.obs_data_first_day_of_the_week) % 7 + + # Replicating the day of the week effect to match the number of + # timepoints + dow_effect = au.tile_until_n( + data=dow_effect_sampled.value, + n_timepoints=latent_hospital_admissions.size, + offset=inf_offset, ) + latent_hospital_admissions = latent_hospital_admissions * dow_effect + # Applying reporting probability latent_hospital_admissions = ( latent_hospital_admissions - * self.hosp_report_prob_rv(**kwargs)[0].value + * self.hospitalization_reporting_ratio_rv(**kwargs)[0].value ) numpyro.deterministic( @@ -218,4 +263,9 @@ def sample( t_start=self.t_start, t_unit=self.t_unit, ), + multiplier=SampledValue( + dow_effect, + t_start=self.t_start, + t_unit=self.t_unit, + ), ) diff --git a/src/pyrenew/model/admissionsmodel.py b/src/pyrenew/model/admissionsmodel.py index 7977360b..566783b0 100644 --- a/src/pyrenew/model/admissionsmodel.py +++ b/src/pyrenew/model/admissionsmodel.py @@ -201,7 +201,7 @@ def sample( latent_hosp_admissions, *_, ) = self.latent_hosp_admissions_rv( - latent_infections=basic_model.latent_infections.value, + latent_infections=basic_model.latent_infections, **kwargs, ) diff --git a/src/test/test_latent_admissions.py b/src/test/test_latent_admissions.py index af0034d6..92acec2b 100644 --- a/src/test/test_latent_admissions.py +++ b/src/test/test_latent_admissions.py @@ -1,16 +1,20 @@ # -*- coding: utf-8 -*- # numpydoc ignore=GL08 -from test.utils import simple_rt - import jax.numpy as jnp import numpy.testing as testing import numpyro import numpyro.distributions as dist -from pyrenew.deterministic import DeterministicPMF +from pyrenew import transformation as t +from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.latent import HospitalAdmissions, Infections -from pyrenew.metaclass import DistributionalRV +from pyrenew.metaclass import ( + DistributionalRV, + SampledValue, + TransformedRandomVariable, +) +from pyrenew.process import SimpleRandomWalkProcess def test_admissions_sample(): @@ -21,7 +25,19 @@ def test_admissions_sample(): # Generating Rt and Infections to compute the hospital admissions - rt = simple_rt() + rt = TransformedRandomVariable( + name="Rt_rv", + base_rv=SimpleRandomWalkProcess( + name="log_rt", + step_rv=DistributionalRV( + name="rw_step_rv", distribution=dist.Normal(0, 0.025) + ), + init_rv=DistributionalRV( + name="init_log_rt", distribution=dist.Normal(0, 0.2) + ), + ), + transforms=t.ExpTransform(), + ) with numpyro.handlers.seed(rng_seed=223): sim_rt = rt(n_steps=30)[0].value @@ -63,15 +79,70 @@ def test_admissions_sample(): hosp1 = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, - infect_hosp_rate_rv=DistributionalRV( + infection_hospitalization_ratio_rv=DistributionalRV( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), 0.05) ), ) with numpyro.handlers.seed(rng_seed=223): - sim_hosp_1 = hosp1(latent_infections=inf_sampled1[0].value) + sim_hosp_1 = hosp1(latent_infections=inf_sampled1[0]) testing.assert_array_less( sim_hosp_1.latent_hospital_admissions.value, inf_sampled1[0].value, ) + + # Testing the offset in the observed data + inf_hosp2 = jnp.ones(30) + inf_hosp2 = DeterministicPMF("i2h", inf_hosp2 / sum(inf_hosp2)) + + dow_effect = jnp.array([1, 1, 1, 1, 0.5, 0.5, 0.5]) + dow_effect = DeterministicPMF( + name="dow_effect", + value=dow_effect / sum(dow_effect), + ) + + dow_effect_wrong = DeterministicPMF( + name="dow_effect", + value=jnp.array([0.3, 0.3, 1 - 0.6]), + ) + hosp2a = HospitalAdmissions( + infection_to_admission_interval_rv=inf_hosp2, + infection_hospitalization_ratio_rv=DeterministicVariable("ihr", 1), + day_of_week_effect_rv=dow_effect, + obs_data_first_day_of_the_week=0, + ) + + hosp2b = HospitalAdmissions( + infection_to_admission_interval_rv=inf_hosp2, + infection_hospitalization_ratio_rv=DeterministicVariable("ihr", 1), + day_of_week_effect_rv=dow_effect, + obs_data_first_day_of_the_week=2, + ) + + hosp3b = HospitalAdmissions( + infection_to_admission_interval_rv=inf_hosp2, + infection_hospitalization_ratio_rv=DeterministicVariable("ihr", 1), + day_of_week_effect_rv=dow_effect_wrong, + obs_data_first_day_of_the_week=2, + ) + + inf_sampled2 = SampledValue(jnp.ones(30)) + + with numpyro.handlers.seed(rng_seed=223): + sim_hosp_2a = hosp2a(latent_infections=inf_sampled2).multiplier.value + + with numpyro.handlers.seed(rng_seed=223): + sim_hosp_2b = hosp2b(latent_infections=inf_sampled2).multiplier.value + + with numpyro.handlers.seed(rng_seed=223): + with testing.assert_raises(ValueError): + hosp3b(latent_infections=inf_sampled2).multiplier.value + + testing.assert_array_equal( + sim_hosp_2a[2 : (sim_hosp_2b.size - 2)], + sim_hosp_2b[: (sim_hosp_2b.size - 4)], + ) + + +test_admissions_sample() diff --git a/src/test/test_model_hosp_admissions.py b/src/test/test_model_hosp_admissions.py index 4c67ff69..fc2e4f57 100644 --- a/src/test/test_model_hosp_admissions.py +++ b/src/test/test_model_hosp_admissions.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- # numpydoc ignore=GL08 - from test.utils import simple_rt import jax.numpy as jnp @@ -29,7 +28,12 @@ class UniformProbForTest(RandomVariable): # numpydoc ignore=GL08 - def __init__(self, pname: str): # numpydoc ignore=GL08 + def __init__( + self, + size: int, + pname: str, + ): # numpydoc ignore=GL08 + self.size = size self.name = pname return None @@ -42,7 +46,9 @@ def sample(self, **kwargs): # numpydoc ignore=GL08 return ( SampledValue( numpyro.sample( - name=self.name, fn=dist.Uniform(high=0.99, low=0.01) + name=self.name, + fn=dist.Uniform(high=0.99, low=0.01), + sample_shape=(self.size,), ) ), ) @@ -93,7 +99,7 @@ def test_model_hosp_no_timepoints_or_observations(): latent_admissions = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, - infect_hosp_rate_rv=DistributionalRV( + infection_hospitalization_ratio_rv=DistributionalRV( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), 0.05) ), ) @@ -158,7 +164,7 @@ def test_model_hosp_both_timepoints_and_observations(): latent_admissions = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, - infect_hosp_rate_rv=DistributionalRV( + infection_hospitalization_ratio_rv=DistributionalRV( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), 0.05) ), ) @@ -229,7 +235,7 @@ def test_model_hosp_no_obs_model(): latent_admissions = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, - infect_hosp_rate_rv=DistributionalRV( + infection_hospitalization_ratio_rv=DistributionalRV( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), 0.05), ), @@ -340,7 +346,7 @@ def test_model_hosp_with_obs_model(): latent_admissions = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, - infect_hosp_rate_rv=DistributionalRV( + infection_hospitalization_ratio_rv=DistributionalRV( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), 0.05), ), @@ -425,20 +431,14 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): ), ) - # Other random components - weekday = jnp.array([1, 1, 1, 1, 2, 2]) - weekday = weekday / weekday.sum() - weekday = jnp.tile(weekday, 10) - weekday = weekday[:31] - - hosp_report_prob_dist = UniformProbForTest("hosp_report_prob_dist") - weekday = UniformProbForTest("weekday") + hosp_report_prob_dist = UniformProbForTest(1, "hosp_report_prob_dist") + weekday = UniformProbForTest(7, "weekday") latent_admissions = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, day_of_week_effect_rv=weekday, - hosp_report_prob_rv=hosp_report_prob_dist, - infect_hosp_rate_rv=DistributionalRV( + hospitalization_reporting_ratio_rv=hosp_report_prob_dist, + infection_hospitalization_ratio_rv=DistributionalRV( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), 0.05) ), ) @@ -527,10 +527,8 @@ def test_model_hosp_with_obs_model_weekday_phosp(): # Other random components total_length = n_obs_to_generate + pad_size + gen_int.size() - weekday = jnp.array([1, 1, 1, 1, 2, 2]) + weekday = jnp.array([1, 1, 1, 1, 2, 2, 2]) weekday = weekday / weekday.sum() - weekday = jnp.tile(weekday, 10) - weekday = weekday[:total_length] weekday = DeterministicVariable(name="weekday", value=weekday) @@ -547,8 +545,8 @@ def test_model_hosp_with_obs_model_weekday_phosp(): latent_admissions = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, day_of_week_effect_rv=weekday, - hosp_report_prob_rv=hosp_report_prob_dist, - infect_hosp_rate_rv=DistributionalRV( + hospitalization_reporting_ratio_rv=hosp_report_prob_dist, + infection_hospitalization_ratio_rv=DistributionalRV( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), 0.05), ),