diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index b06de44a..e9a3fcba 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -16,7 +16,7 @@ import jax.numpy as jnp import numpy as np import numpyro import numpyro.distributions as dist -from pyrenew.process import SimpleRandomWalkProcess +from pyrenew.process import RandomWalk from pyrenew.latent import ( Infections, InfectionInitializationProcess, @@ -68,7 +68,7 @@ flowchart LR end subgraph process[Process module] - rt["Rt_process_rv\n(Custom class built using SimpleRandomWalk)"] + rt["Rt_process_rv\n(Custom class built using RandomWalk)"] end subgraph deterministic[Deterministic module] @@ -139,26 +139,25 @@ class MyRt(RandomVariable): def validate(self): pass - def sample(self, n_steps: int, **kwargs) -> tuple: + def sample(self, n: int, **kwargs) -> tuple: sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025)) rt_rv = TransformedRandomVariable( - "Rt_rv", - base_rv=SimpleRandomWalkProcess( + name="log_rt_random_walk", + base_rv=RandomWalk( name="log_rt", step_rv=DistributionalRV( - name="rw_step_rv", - distribution=dist.Normal(0, sd_rt), - reparam=LocScaleReparam(0), - ), - init_rv=DistributionalRV( - name="init_log_rt", - distribution=dist.Normal(jnp.log(1), jnp.log(1.2)), + name="rw_step_rv", distribution=dist.Normal(0, 0.025) ), ), transforms=t.ExpTransform(), ) - return rt_rv.sample(n_steps=n_steps, **kwargs) + rt_init_rv = DistributionalRV( + name="init_log_rt", distribution=dist.Normal(0, 0.2) + ) + init_rt, *_ = rt_init_rv.sample() + + return rt_rv.sample(n=n, init_vals=init_rt.value, **kwargs) rt_proc = MyRt() diff --git a/docs/source/tutorials/day_of_the_week.qmd b/docs/source/tutorials/day_of_the_week.qmd index 4bdb110b..7337dda9 100644 --- a/docs/source/tutorials/day_of_the_week.qmd +++ b/docs/source/tutorials/day_of_the_week.qmd @@ -106,7 +106,7 @@ class 
MyRt(metaclass.RandomVariable): def validate(self): pass - def sample(self, n_steps: int, **kwargs) -> tuple: + def sample(self, n: int, **kwargs) -> tuple: # Standard deviation of the random walk sd_rt, *_ = self.sd_rv() @@ -115,15 +115,14 @@ class MyRt(metaclass.RandomVariable): name="rw_step_rv", distribution=dist.Normal(0, sd_rt.value) ) - init_rv = metaclass.DistributionalRV( + rt_init_rv = metaclass.DistributionalRV( name="init_log_rt", distribution=dist.Normal(0, 0.2) ) # Random walk process - base_rv = process.SimpleRandomWalkProcess( + base_rv = process.RandomWalk( name="log_rt", step_rv=step_rv, - init_rv=init_rv, ) # Transforming the random walk to the Rt scale @@ -132,8 +131,9 @@ class MyRt(metaclass.RandomVariable): base_rv=base_rv, transforms=transformation.ExpTransform(), ) + init_rt, *_ = rt_init_rv.sample() - return rt_rv(n_steps=n_steps, **kwargs) + return rt_rv.sample(n=n, init_vals=init_rt.value, **kwargs) rtproc = MyRt( diff --git a/docs/source/tutorials/day_of_the_week.rst b/docs/source/tutorials/day_of_the_week.rst index ef1ea658..ddcaccbd 100644 --- a/docs/source/tutorials/day_of_the_week.rst +++ b/docs/source/tutorials/day_of_the_week.rst @@ -2,4 +2,4 @@ .. Please do not edit this file directly. .. This file is just a placeholder. .. For the source file, see: -.. +.. 
diff --git a/docs/source/tutorials/extending_pyrenew.qmd b/docs/source/tutorials/extending_pyrenew.qmd index 5e854a4f..14615485 100644 --- a/docs/source/tutorials/extending_pyrenew.qmd +++ b/docs/source/tutorials/extending_pyrenew.qmd @@ -28,8 +28,12 @@ import numpyro.distributions as dist from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.latent import InfectionsWithFeedback from pyrenew.model import RtInfectionsRenewalModel -from pyrenew.process import SimpleRandomWalkProcess -from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable +from pyrenew.process import RandomWalk +from pyrenew.metaclass import ( + RandomVariable, + DistributionalRV, + TransformedRandomVariable, +) from pyrenew.latent import ( InfectionInitializationProcess, InitializeInfectionsExponentialGrowth, @@ -62,19 +66,31 @@ latent_infections = InfectionsWithFeedback( infection_feedback_pmf=gen_int, ) -rt = TransformedRandomVariable( - "Rt_rv", - base_rv=SimpleRandomWalkProcess( - name="log_rt", - step_rv=DistributionalRV( - name="rw_step_rv", distribution=dist.Normal(0, 0.025) - ), - init_rv=DistributionalRV( + +class MyRt(RandomVariable): + + def validate(self): + pass + + def sample(self, n: int, **kwargs) -> tuple: + sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025)) + + rt_rv = TransformedRandomVariable( + name="log_rt_random_walk", + base_rv=RandomWalk( + name="log_rt", + step_rv=DistributionalRV( + name="rw_step_rv", distribution=dist.Normal(0, 0.025) + ), + ), + transforms=t.ExpTransform(), + ) + rt_init_rv = DistributionalRV( name="init_log_rt", distribution=dist.Normal(0, 0.2) - ), - ), - transforms=t.ExpTransform(), -) + ) + init_rt, *_ = rt_init_rv.sample() + + return rt_rv.sample(n=n, init_vals=init_rt.value, **kwargs) ``` With all the components defined, we can build the model: @@ -85,7 +101,7 @@ model0 = RtInfectionsRenewalModel( gen_int_rv=gen_int, I0_rv=I0, latent_infections_rv=latent_infections, - 
Rt_process_rv=rt, + Rt_process_rv=MyRt(), infection_obs_process_rv=None, ) ``` @@ -209,10 +225,13 @@ class InfFeedback(RandomVariable): inf_feedback_strength, *_ = self.infection_feedback_strength( **kwargs, ) + + inf_feedback_strength = jnp.atleast_1d(inf_feedback_strength.value) + inf_feedback_strength = au.pad_x_to_match_y( - x=inf_feedback_strength.value, + x=inf_feedback_strength, y=Rt, - fill_value=inf_feedback_strength.value[0], + fill_value=inf_feedback_strength[0], ) # Sampling inf feedback and adjusting the shape @@ -260,7 +279,7 @@ model1 = RtInfectionsRenewalModel( gen_int_rv=gen_int, I0_rv=I0, latent_infections_rv=latent_infections2, - Rt_process_rv=rt, + Rt_process_rv=MyRt(), infection_obs_process_rv=None, ) diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index 0889782f..b43f41d4 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -186,37 +186,32 @@ gen_int = deterministic.DeterministicPMF(name="gen_int", value=gen_int) class MyRt(metaclass.RandomVariable): - def __init__(self, sd_rv): - self.sd_rv = sd_rv def validate(self): pass - def sample(self, n_steps: int, **kwargs) -> tuple: - sd_rt, *_ = self.sd_rv() + def sample(self, n: int, **kwargs) -> tuple: + sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025)) rt_rv = metaclass.TransformedRandomVariable( - "Rt_rv", - base_rv=process.SimpleRandomWalkProcess( + name="log_rt_random_walk", + base_rv=process.RandomWalk( name="log_rt", step_rv=metaclass.DistributionalRV( - name="rw_step_rv", distribution=dist.Normal(0, sd_rt.value) - ), - init_rv=metaclass.DistributionalRV( - name="init_log_rt", distribution=dist.Normal(0, 0.2) + name="rw_step_rv", distribution=dist.Normal(0, 0.025) ), ), transforms=transformation.ExpTransform(), ) + rt_init_rv = metaclass.DistributionalRV( + name="init_log_rt", distribution=dist.Normal(0, 0.2) + ) + init_rt, *_ = 
rt_init_rv.sample() - return rt_rv.sample(n_steps=n_steps, **kwargs) + return rt_rv.sample(n=n, init_vals=init_rt.value, **kwargs) -rtproc = MyRt( - metaclass.DistributionalRV( - name="Rt_random_walk_sd", distribution=dist.HalfNormal(0.025) - ) -) +rtproc = MyRt() # The observation model diff --git a/docs/source/tutorials/periodic_effects.qmd b/docs/source/tutorials/periodic_effects.qmd index 7586776b..bfe3e30d 100644 --- a/docs/source/tutorials/periodic_effects.qmd +++ b/docs/source/tutorials/periodic_effects.qmd @@ -24,7 +24,7 @@ from pyrenew import process, deterministic ```{python} # The random process for Rt -rt_proc = process.RtWeeklyDiffProcess( +rt_proc = process.RtWeeklyDiffARProcess( name="rt_weekly_diff", offset=0, log_rt_rv=deterministic.DeterministicVariable( @@ -57,7 +57,7 @@ for i in range(0, 30, 7): plt.show() ``` -The implementation of the `RtWeeklyDiffProcess` (which is an instance of `RtPeriodicDiffProcess`), uses `repeat_until_n` to repeating values: `repeat_until_n(..., period_size=7)`. The `RtWeeklyDiff` class is a particular case of `RtPeriodicDiff` with a period size of seven. +The implementation of the `RtWeeklyDiffARProcess` (which is an instance of `RtPeriodicDiffARProcess`), uses `repeat_until_n` to repeat values: `repeat_until_n(..., period_size=7)`. The `RtWeeklyDiffARProcess` class is a particular case of `RtPeriodicDiffARProcess` with a period size of seven. 
## Repeated sequences (tiling) diff --git a/pyproject.toml b/pyproject.toml index dfdae0a0..acccc353 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ sphinxcontrib-mermaid = "^0.9.2" sphinx-autodoc-typehints = "^2.1.0" sphinx-book-theme = "^1.1.2" ipywidgets = "^8.1.3" +jupyter = "^1.0.0" [tool.poetry.group.test] optional = true @@ -49,6 +50,7 @@ optional = true pytest = "^8.3.2" pytest-cov = "^5.0.0" pytest-mpl = "^0.17.0" +scipy = "^1.14.1" [tool.numpydoc_validation] checks = [ @@ -81,4 +83,4 @@ build-backend = "poetry.core.masonry.api" known_first_party = ["pyrenew", "test"] [tool.deptry.per_rule_ignores] -DEP004 = ["pytest"] +DEP004 = ["pytest", "scipy"] diff --git a/src/pyrenew/deterministic/deterministic.py b/src/pyrenew/deterministic/deterministic.py index 8a8d28d4..13970e69 100644 --- a/src/pyrenew/deterministic/deterministic.py +++ b/src/pyrenew/deterministic/deterministic.py @@ -3,7 +3,6 @@ from __future__ import annotations -import jax.numpy as jnp import numpyro from jax.typing import ArrayLike @@ -42,7 +41,7 @@ def __init__( """ self.name = name self.validate(value) - self.value = jnp.atleast_1d(value) + self.value = value self.set_timeseries(t_start, t_unit) return None diff --git a/src/pyrenew/latent/infection_initialization_method.py b/src/pyrenew/latent/infection_initialization_method.py index a16b4632..537c2ca5 100644 --- a/src/pyrenew/latent/infection_initialization_method.py +++ b/src/pyrenew/latent/infection_initialization_method.py @@ -87,6 +87,7 @@ def initialize_infections(self, I_pre_init: ArrayLike): ArrayLike An array of length ``n_timepoints`` with the number of initialized infections at each time point. """ + I_pre_init = jnp.atleast_1d(I_pre_init) if self.n_timepoints < I_pre_init.size: raise ValueError( "I_pre_init must be no longer than n_timepoints. 
" @@ -105,20 +106,23 @@ def initialize_infections(self, I_pre_init: ArrayLike): Parameters ---------- I_pre_init : ArrayLike - An array with the same length as ``n_timepoints`` to be used as the initial infections. + An array with the same length as ``n_timepoints`` to be + used as the initial infections. Returns ------- ArrayLike - An array of length ``n_timepoints`` with the number of initialized infections at each time point. + An array of length ``n_timepoints`` with the number of + initialized infections at each time point. """ + I_pre_init = jnp.array(I_pre_init) if I_pre_init.size != self.n_timepoints: raise ValueError( "I_pre_init must have the same size as n_timepoints. " f"Got I_pre_init of size {I_pre_init.size} " f"and n_timepoints of size {self.n_timepoints}." ) - return jnp.array(I_pre_init) + return I_pre_init class InitializeInfectionsExponentialGrowth(InfectionInitializationMethod): @@ -173,11 +177,12 @@ def initialize_infections(self, I_pre_init: ArrayLike): ArrayLike An array of length ``n_timepoints`` with the number of initialized infections at each time point. """ + I_pre_init = jnp.array(I_pre_init) if I_pre_init.size != 1: raise ValueError( f"I_pre_init must be an array of size 1. Got size {I_pre_init.size}." ) - rate = self.rate()[0].value + rate = jnp.array(self.rate()[0].value) if rate.size != 1: raise ValueError( f"rate must be an array of size 1. Got size {rate.size}." 
diff --git a/src/pyrenew/latent/infection_initialization_process.py b/src/pyrenew/latent/infection_initialization_process.py index eefc5314..cb52698d 100644 --- a/src/pyrenew/latent/infection_initialization_process.py +++ b/src/pyrenew/latent/infection_initialization_process.py @@ -1,7 +1,5 @@ # -*- coding: utf-8 -*- # numpydoc ignore=GL08 -import numpyro - from pyrenew.latent.infection_initialization_method import ( InfectionInitializationMethod, ) @@ -97,7 +95,6 @@ def sample(self) -> tuple: infection_initialization = self.infection_init_method( I_pre_init.value, ) - numpyro.deterministic(self.name, infection_initialization) return ( SampledValue( diff --git a/src/pyrenew/latent/infectionswithfeedback.py b/src/pyrenew/latent/infectionswithfeedback.py index ad5214d1..f3d8ec8d 100644 --- a/src/pyrenew/latent/infectionswithfeedback.py +++ b/src/pyrenew/latent/infectionswithfeedback.py @@ -160,10 +160,11 @@ def sample( I0 = I0[-gen_int_rev.size :] # Sampling inf feedback strength - inf_feedback_strength = self.infection_feedback_strength( - **kwargs, - )[0].value - + inf_feedback_strength = jnp.atleast_1d( + self.infection_feedback_strength( + **kwargs, + )[0].value + ) # Making sure inf_feedback_strength spans the Rt length if inf_feedback_strength.size == 1: inf_feedback_strength = au.pad_x_to_match_y( @@ -173,8 +174,9 @@ def sample( ) elif inf_feedback_strength.size != Rt.size: raise ValueError( - "Infection feedback strength must be of size 1 or the same " - f"size as the reproduction number. Got {inf_feedback_strength.size} " + "Infection feedback strength must be of size 1 " + "or the same size as the reproduction number array. " + f"Got {inf_feedback_strength.size} " f"and {Rt.size} respectively." 
) diff --git a/src/pyrenew/metaclass.py b/src/pyrenew/metaclass.py index b5be0a54..424ffead 100644 --- a/src/pyrenew/metaclass.py +++ b/src/pyrenew/metaclass.py @@ -5,10 +5,9 @@ """ from abc import ABCMeta, abstractmethod -from typing import Callable, NamedTuple, get_type_hints +from typing import Callable, NamedTuple, Self, get_type_hints import jax -import jax.numpy as jnp import jax.random as jr import matplotlib.pyplot as plt import numpy as np @@ -290,6 +289,7 @@ def __init__( name: str, distribution_constructor: Callable, reparam: Reparam = None, + expand_by_shape: tuple = None, ) -> None: """ Default constructor for DynamicDistributionalRV. @@ -305,6 +305,11 @@ def __init__( If not None, reparameterize sampling from the distribution according to the given numpyro reparameterizer + expand_by_shape : tuple, optional + If not None, call :meth:`expand_by()` on the + underlying distribution once it is instantiated + with the given `expand_by_shape`. + Default None. Returns ------- @@ -318,6 +323,12 @@ def __init__( self.reparam_dict = {self.name: reparam} else: self.reparam_dict = {} + if not (expand_by_shape is None or isinstance(expand_by_shape, tuple)): + raise ValueError( + "expand_by_shape must be a tuple or be None ", + f"Got {type(expand_by_shape)}", + ) + self.expand_by_shape = expand_by_shape return None @@ -362,7 +373,7 @@ def sample( Positional arguments passed to self.distribution_constructor obs : ArrayLike, optional Observations passed as the `obs` argument to - :fun:`numpyro.sample()`. Default `None`. + :meth:`numpyro.sample()`. Default `None`. **kwargs : dict, optional Keyword arguments passed to self.distribution_constructor @@ -371,20 +382,52 @@ def sample( SampledValue Containing a sample from the distribution. 
""" + distribution = self.distribution_constructor(*args, **kwargs) + if self.expand_by_shape is not None: + distribution = distribution.expand_by(self.expand_by_shape) with numpyro.handlers.reparam(config=self.reparam_dict): sample = numpyro.sample( name=self.name, - fn=self.distribution_constructor(*args, **kwargs), + fn=distribution, obs=obs, ) return ( SampledValue( - jnp.atleast_1d(sample), + sample, t_start=self.t_start, t_unit=self.t_unit, ), ) + def expand_by(self, sample_shape) -> Self: + """ + Expand the distribution by a given + sample_shape, if possible. Returns a + new DynamicDistributionalRV whose underlying + distribution will be expanded by the given shape + at sample() time. + + Parameters + ---------- + sample_shape : tuple + Sample shape by which to expand the distribution. + Passed to the expand_by() method of + :class:`numpyro.distributions.Distribution` + after the distribution is instantiated. + + Returns + ------- + DynamicDistributionalRV + Whose underlying distribution will be expanded by + the given sample shape at sampling time. + """ + return DynamicDistributionalRV( + name=self.name, + distribution_constructor=self.distribution_constructor, + reparam=self.reparam_dict.get(self.name, None), + expand_by_shape=sample_shape, + ) + class StaticDistributionalRV(RandomVariable): """ @@ -432,7 +475,7 @@ def __init__( @staticmethod def validate(distribution: any) -> None: """ - Validation of the distribution to be implemented in subclasses. + Validation of the distribution. """ if not isinstance(distribution, numpyro.distributions.Distribution): raise ValueError( @@ -455,7 +498,7 @@ def sample( ---------- obs : ArrayLike, optional Observations passed as the `obs` argument to - :fun:`numpyro.sample()`. Default `None`. + :meth:`numpyro.sample()`. Default `None`. **kwargs : dict, optional Additional keyword arguments passed through to internal sample calls, should there be any. 
@@ -473,12 +516,45 @@ def sample( ) return ( SampledValue( - jnp.atleast_1d(sample), + sample, t_start=self.t_start, t_unit=self.t_unit, ), ) + def expand_by(self, sample_shape) -> Self: + """ + Expand the distribution by the given sample_shape, + if possible. Returns a new StaticDistributionalRV + whose underlying distribution has been expanded by + the given sample_shape via + :meth:`~numpyro.distributions.Distribution.expand_by()` + + Parameters + ---------- + sample_shape : tuple + Sample shape for the expansion. Passed to the + :meth:`expand_by()` method of + :class:`numpyro.distributions.Distribution`. + + Returns + ------- + StaticDistributionalRV + Whose underlying distribution has been expanded by + the given sample shape. + """ + if not isinstance(sample_shape, tuple): + raise ValueError( + "sample_shape for expand()-ing " + "a DistributionalRV must be a " + f"tuple. Got {type(sample_shape)}" + ) + return StaticDistributionalRV( + name=self.name, + distribution=self.distribution.expand_by(sample_shape), + reparam=self.reparam_dict.get(self.name, None), + ) + def DistributionalRV( name: str, @@ -561,8 +637,8 @@ def sample( Parameters ---------- **kwargs : dict, optional - Additional keyword arguments passed through to internal `sample()` - calls, should there be any. + Additional keyword arguments passed through to internal + `sample()` calls, should there be any. Returns ------- diff --git a/src/pyrenew/model/rtinfectionsrenewalmodel.py b/src/pyrenew/model/rtinfectionsrenewalmodel.py index 410c628d..62cda6ff 100644 --- a/src/pyrenew/model/rtinfectionsrenewalmodel.py +++ b/src/pyrenew/model/rtinfectionsrenewalmodel.py @@ -9,7 +9,6 @@ import numpyro from numpy.typing import ArrayLike -import pyrenew.arrayutils as au from pyrenew.deterministic import NullObservation from pyrenew.metaclass import ( Model, @@ -194,7 +193,7 @@ def sample( # Sampling from Rt (possibly with a given Rt, depending on # the Rt_process (RandomVariable) object.) 
Rt, *_ = self.Rt_process_rv( - n_steps=n_timepoints, + n=n_timepoints, **kwargs, ) @@ -220,29 +219,17 @@ def sample( **kwargs, ) - all_latent_infections = jnp.hstack( - [I0.value, post_initialization_latent_infections.value] + all_latent_infections = SampledValue( + jnp.hstack([I0.value, post_initialization_latent_infections.value]) ) - numpyro.deterministic("all_latent_infections", all_latent_infections) - - if observed_infections is not None: - observed_infections = au.pad_x_to_match_y( - observed_infections.value, - all_latent_infections, - jnp.nan, - pad_direction="start", - ) - - Rt = au.pad_x_to_match_y( - Rt.value, - all_latent_infections, - jnp.nan, - pad_direction="start", + numpyro.deterministic( + "all_latent_infections", all_latent_infections.value ) - numpyro.deterministic("Rt", Rt) + + numpyro.deterministic("Rt", Rt.value) return RtInfectionsRenewalSample( - Rt=SampledValue(Rt), - latent_infections=SampledValue(all_latent_infections), - observed_infections=SampledValue(observed_infections), + Rt=Rt, + latent_infections=all_latent_infections, + observed_infections=observed_infections, ) diff --git a/src/pyrenew/process/__init__.py b/src/pyrenew/process/__init__.py index bad08343..638ea45d 100644 --- a/src/pyrenew/process/__init__.py +++ b/src/pyrenew/process/__init__.py @@ -3,20 +3,27 @@ # numpydoc ignore=GL08 from pyrenew.process.ar import ARProcess -from pyrenew.process.firstdifferencear import FirstDifferenceARProcess +from pyrenew.process.differencedprocess import DifferencedProcess +from pyrenew.process.iidrandomsequence import ( + IIDRandomSequence, + StandardNormalSequence, +) from pyrenew.process.periodiceffect import DayOfWeekEffect, PeriodicEffect -from pyrenew.process.rtperiodicdiff import ( - RtPeriodicDiffProcess, - RtWeeklyDiffProcess, +from pyrenew.process.randomwalk import RandomWalk, StandardNormalRandomWalk +from pyrenew.process.rtperiodicdiffar import ( + RtPeriodicDiffARProcess, + RtWeeklyDiffARProcess, ) -from 
pyrenew.process.simplerandomwalk import SimpleRandomWalkProcess __all__ = [ + "IIDRandomSequence", + "StandardNormalSequence", "ARProcess", - "FirstDifferenceARProcess", - "SimpleRandomWalkProcess", - "RtPeriodicDiffProcess", - "RtWeeklyDiffProcess", + "DifferencedProcess", + "RandomWalk", + "StandardNormalRandomWalk", "PeriodicEffect", "DayOfWeekEffect", + "RtPeriodicDiffARProcess", + "RtWeeklyDiffARProcess", ] diff --git a/src/pyrenew/process/ar.py b/src/pyrenew/process/ar.py index 147e4097..feb10511 100644 --- a/src/pyrenew/process/ar.py +++ b/src/pyrenew/process/ar.py @@ -4,55 +4,38 @@ from __future__ import annotations import jax.numpy as jnp -import numpyro -import numpyro.distributions as dist -from jax import lax from jax.typing import ArrayLike +from numpyro.contrib.control_flow import scan from pyrenew.metaclass import RandomVariable, SampledValue +from pyrenew.process.iidrandomsequence import StandardNormalSequence class ARProcess(RandomVariable): """ - Object to represent - an AR(p) process in - Numpyro + RandomVariable representing + an AR(p) process. """ - def __init__( - self, - name: str, - mean: float, - autoreg: ArrayLike, - noise_sd: float, - ) -> None: + def __init__(self, noise_rv_name: str, *args, **kwargs) -> None: """ - Default constructor + Default constructor. Parameters ---------- - name : str - Name of the parameter passed to numpyro.sample. - mean: float - Mean parameter. - autoreg : ArrayLike - Model parameters. The shape determines the order. - noise_sd : float - Standard error for the noise component. - - Returns - ------- - None + noise_rv_name : str + A name for the internal RandomVariable + holding the process noise. 
""" - self.name = name - self.mean = mean - self.autoreg = autoreg - self.noise_sd = noise_sd + super().__init__(*args, **kwargs) + self.noise_rv_ = StandardNormalSequence(element_rv_name=noise_rv_name) def sample( self, - duration: int, - inits: ArrayLike = None, + n: int, + autoreg: ArrayLike, + init_vals: ArrayLike, + noise_sd: float | ArrayLike, **kwargs, ) -> tuple: """ @@ -60,41 +43,74 @@ def sample( Parameters ---------- - duration: int + n: int Length of the sequence. - inits : ArrayLike, optional - Initial points, if None, then these are sampled. - Defaults to None. + autoreg: ArrayLike + Autoregressive coefficients. + The length of the array's first + dimension determines the order :math:`p` + of the AR process. + init_vals : ArrayLike + Array of initial values. Must have the + same first dimension size as the order. + noise_sd : float | ArrayLike + Scalar giving the s.d. of the AR + process Normal noise, which by + definition has mean 0. **kwargs : dict, optional - Additional keyword arguments passed through to internal sample() - calls, should there be any. + Additional keyword arguments passed to + self.noise_rv_.sample() Returns ------- tuple - With a single array of shape (duration,). + With a single SampledValue containing an + array of shape (n,). """ - order = self.autoreg.shape[0] - if inits is None: - inits = numpyro.sample( - self.name + "_sampled_inits", - dist.Normal(0, self.noise_sd).expand((order,)), + noise_sd_arr = jnp.atleast_1d(jnp.array(noise_sd)) + if not noise_sd_arr.shape == (1,): + raise ValueError("noise_sd must be a scalar. 
" f"Got {noise_sd}") + autoreg = jnp.atleast_1d(jnp.array(autoreg)) + init_vals = jnp.atleast_1d(jnp.array(init_vals)) + + if not autoreg.ndim == 1: + raise ValueError( + "Array of autoregressive coefficients " + "must be no more than 1 dimension", + f"Got {autoreg.ndim}", + ) + if not init_vals.ndim == 1: + raise ValueError( + "Array of initial values must be " "no more than 1 dimension", + f"Got {init_vals.ndim}", + ) + order = autoreg.size + if not init_vals.size == order: + raise ValueError( + "Array of initial values must be " + "be the same size as the order of " + "the autoregressive process, " + "which is determined by the number " + "of autoregressive coefficients " + "provided. Got {init_vals.size} " + "initial values for a process of " + f"order {order}" ) - def _ar_scanner(carry, next): # numpydoc ignore=GL08 - new_term = (jnp.dot(self.autoreg, carry) + next).flatten() - new_carry = jnp.hstack([new_term, carry[: (order - 1)]]) - return new_carry, new_term + raw_noise, *_ = self.noise_rv_(n=n, **kwargs) + noise = noise_sd_arr * raw_noise.value - noise = numpyro.sample( - self.name + "_noise", - dist.Normal(0, self.noise_sd).expand((duration - inits.size,)), - ) + def transition(recent_vals, next_noise): # numpydoc ignore=GL08 + new_term = jnp.dot(autoreg, recent_vals) + next_noise + new_recent_vals = jnp.hstack( + [new_term, recent_vals[: (order - 1)]] + ) + return new_recent_vals, new_term - last, ts = lax.scan(_ar_scanner, inits - self.mean, noise) + last, ts = scan(transition, init_vals, noise) return ( SampledValue( - jnp.hstack([inits, self.mean + ts.flatten()]), + jnp.hstack([init_vals, ts]), t_start=self.t_start, t_unit=self.t_unit, ), diff --git a/src/pyrenew/process/differencedprocess.py b/src/pyrenew/process/differencedprocess.py new file mode 100644 index 00000000..b19173f2 --- /dev/null +++ b/src/pyrenew/process/differencedprocess.py @@ -0,0 +1,253 @@ +# -*- coding: utf-8 -*- +# numpydoc ignore=GL08 + +from __future__ import annotations + 
+import jax.numpy as jnp +from jax.typing import ArrayLike +from numpyro.contrib.control_flow import scan + +from pyrenew.metaclass import RandomVariable, SampledValue + + +class DifferencedProcess(RandomVariable): + """ + Class for differenced stochastic process X(t), + constructed by placing a fundamental stochastic + process on the :math:`n^{th}` differences + (rates of change). See + https://otexts.com/fpp3/stationarity.html + for a discussion of differencing in the + context of discrete timeseries data. + """ + + def __init__( + self, + fundamental_process: RandomVariable, + differencing_order: int, + **kwargs, + ) -> None: + """ + Default constructor + + Parameters + ---------- + fundamental_process : RandomVariable + Stochastic process for the + differences. Should accept an + `n` argument specifying the number + of samples to draw. + differencing_order : int + How many fold-differencing + the process represents. Must be + an integer greater than or + equal to 1. 1 represents a process + on the first differences (the rate + of change), 2 a process on the + 2nd differences (rate of change of + the rate of change), et cetera. + **kwargs : + Additional keyword arguments passed to + the parent class constructor. + + Returns + ------- + None + + Notes + ----- + The order of differencing is the discrete + analogue of the order of a derivative in single + variable calculus. A first difference (derivative) + represents a rate of change. A second difference + (derivative) represents the rate of change of that + rate of change, et cetera. + """ + self.assert_valid_differencing_order(differencing_order) + self.differencing_order = differencing_order + self.fundamental_process = fundamental_process + super().__init__(**kwargs) + + def integrate( + self, init_diff_vals: ArrayLike, highest_order_diff_vals: ArrayLike + ): + """ + Integrate (de-difference) the differenced process, + obtaining the process values :math:`X(t=0), X(t=1), ... 
X(t)` + from the :math:`n^{th}` differences and a set of + initial process / difference values + :math:`X(t=0), X^1(t=1), X^2(t=2), ... X^{(n-1)}(t=n-1)`, + where :math:`X^k(t)` is the value of the :math:`k^{th}` + difference at index :math:`t` of the process, + obtaining a sequence of length equal to the length of + the provided `highest_order_diff_vals` vector plus + the order of the process. + + Parameters + ---------- + init_diff_vals : ArrayLike + Values of + :math:`X(t=0), X^1(t=1), X^2(t=2) ... X^{(n-1)}(t=n-1)`. + + highest_order_diff_vals : ArrayLike + Array of differences at the highest order of + differencing, i.e. the order of the overall process, + starting with :math:`X^{n}(t=n)` + + Returns + ------- + The integrated (de-differenced) sequence of values, + of length n_diffs + order, where n_diffs is the + number of highest_order_diff_vals and order is the + order of the process. + """ + init_arr = jnp.atleast_1d(init_diff_vals) + diff_arr = jnp.atleast_1d(highest_order_diff_vals) + if not init_arr.ndim == 1: + raise ValueError( + "init_diff_vals must be 1-dimensional " + "array or a scalar. " + f"Got {init_diff_vals}" + ) + if not diff_arr.ndim == 1: + raise ValueError( + "highest_order_diff_vals must be a " + "1-dimensional array or a scalar " + f"Got {highest_order_diff_vals}" + ) + n_inits = init_arr.size + if not n_inits == self.differencing_order: + raise ValueError( + "Must have exactly as many " + "initial difference values as " + "the differencing order, given " + "in the sequence :math:`X(t=0), X^1(t=1),` " + "et cetera. 
" + f"Got {n_inits} values " + "for a process of order " + f"{self.differencing_order}" + ) + + def _integrate_one_step(diffs, scanned): + # numpydoc ignore=GL08 + order, init = scanned + new_diffs = jnp.cumsum(diffs.at[order].set(init)) + return (new_diffs, None) + + integrated, _ = scan( + _integrate_one_step, + init=jnp.pad(diff_arr, (self.differencing_order, 0)), + xs=( + jnp.flip(jnp.arange(self.differencing_order)), + jnp.flip(init_arr), + ), + ) + + return integrated + + def sample( + self, + init_vals: ArrayLike, + n: int, + *args, + fundamental_process_init_vals: ArrayLike = None, + **kwargs, + ) -> tuple: + """ + Sample from the process + + Parameters + ---------- + init_vals : ArrayLike + initial values for the :math:`0^{th}` through + :math:`(n-1)^{st}` differences, passed as the + ``init_diff_vals`` argument to + :meth:`DifferencedProcess.integrate()` + + n : int + Number of values to sample. Will sample + ``n - self.differencing_order`` values from + :meth:`self.fundamental_process` to ensure + that the de-differenced output is of length + ``n``. + + *args : + Additional positional arguments passed to + :meth:`self.fundamental_process.sample()` + + fundamental_process_init_vals : ArrayLike + Initial values for the fundamental process. + Passed as the ``init_vals`` keyword argument + to :meth:`self.fundamental_process.sample()`. + + **kwargs : dict, optional + Keyword arguments passed to + :meth:`self.fundamental_process.sample()`. + + Returns + ------- + SampledValue + Whose value entry is a single array representing the + undifferenced timeseries + """ + if not isinstance(n, int): + raise ValueError("n must be an integer. " f"Got {type(n)}") + if n < 1: + raise ValueError("n must be positive. 
" f"Got {n}") + + n_diffs = n - self.differencing_order + if n_diffs > 0: + diff_samp, *_ = self.fundamental_process.sample( + *args, + n=n_diffs, + init_vals=fundamental_process_init_vals, + **kwargs, + ) + diffs = diff_samp.value + else: + diffs = jnp.array([]) + integrated_ts = self.integrate(init_vals, diffs)[:n] + return ( + SampledValue( + value=integrated_ts, + t_start=self.t_start, + t_unit=self.t_unit, + ), + ) + + @staticmethod + def validate(): + """ + Validates input parameters, implementation pending. + """ + return None + + @staticmethod + def assert_valid_differencing_order(differencing_order: any): + """ + To be valid, a differencing order must + be an integer and must be strictly positive. + This function raises a value error if its + argument is not a valid differencing order. + + Parameter + --------- + differcing_order : any + Potential differencing order to validate. + + Returns + ------- + None or raises a ValueError + """ + if not isinstance(differencing_order, int): + raise ValueError( + "differencing_order must be an integer. " + f"got type {type(differencing_order)} " + f"and value {differencing_order}" + ) + if not differencing_order >= 1: + raise ValueError( + "differencing_order must be an integer " + "greater than or equal to 1. Got " + f"{differencing_order}" + ) diff --git a/src/pyrenew/process/firstdifferencear.py b/src/pyrenew/process/firstdifferencear.py deleted file mode 100644 index 63a1e826..00000000 --- a/src/pyrenew/process/firstdifferencear.py +++ /dev/null @@ -1,92 +0,0 @@ -# -*- coding: utf-8 -*- -# numpydoc ignore=GL08 - -from __future__ import annotations - -import jax.numpy as jnp -from jax.typing import ArrayLike - -from pyrenew.metaclass import RandomVariable, SampledValue -from pyrenew.process import ARProcess - - -class FirstDifferenceARProcess(RandomVariable): - """ - Class for a stochastic process - with an AR(1) process on the first - differences (i.e. the rate of change). 
- """ - - def __init__( - self, - name: str, - autoreg: ArrayLike, - noise_sd: float, - ) -> None: - """ - Default constructor - - Parameters - ---------- - name : str - Passed to ARProcess() - autoreg : ArrayLike - Process parameters pyrenew.processesARprocess. - noise_sd : float - Error passed to pyrenew.processes.ARProcess. - - Returns - ------- - None - """ - self.rate_of_change_proc = ARProcess( - "arprocess", 0, jnp.array([autoreg]), noise_sd - ) - self.name = name - - def sample( - self, - duration: int, - init_val: ArrayLike = None, - init_rate_of_change: ArrayLike = None, - **kwargs, - ) -> tuple: - """ - Sample from the process - - Parameters - ---------- - duration : int - Passed to ARProcess() - init_val : ArrayLike, optional - Starting point of the AR process, by default None. - init_rate_of_change : ArrayLike, optional - Passed to ARProcess.sample, by default None. - **kwargs : dict, optional - Additional keyword arguments passed through to internal sample() - calls, should there be any. - - Returns - ------- - tuple - With a single array of shape (duration,). - """ - rates_of_change, *_ = self.rate_of_change_proc.sample( - name=self.name + "_rate_of_change", - duration=duration, - inits=jnp.atleast_1d(init_rate_of_change), - ) - return ( - SampledValue( - init_val + jnp.cumsum(rates_of_change.value.flatten()), - t_start=self.t_start, - t_unit=self.t_unit, - ), - ) - - @staticmethod - def validate(): - """ - Validates input parameters, implementation pending. 
- """ - return None diff --git a/src/pyrenew/process/iidrandomsequence.py b/src/pyrenew/process/iidrandomsequence.py new file mode 100644 index 00000000..2f868ada --- /dev/null +++ b/src/pyrenew/process/iidrandomsequence.py @@ -0,0 +1,145 @@ +# -*- coding: utf-8 -*- +# numpydoc ignore=GL08 + +import numpyro.distributions as dist +from numpyro.contrib.control_flow import scan + +from pyrenew.metaclass import DistributionalRV, RandomVariable, SampledValue + + +class IIDRandomSequence(RandomVariable): + """ + Class for constructing random sequence of + independent and identically distributed elements + given an arbitrary RandomVariable representing + those elements. + """ + + def __init__( + self, + element_rv: RandomVariable, + **kwargs, + ) -> None: + """ + Default constructor + + Parameters + ---------- + name : str + A name for the random variable, used to + name sites within it in :meth:`numpyro.sample()` + calls. + element_rv : RandomVariable + RandomVariable representing a single element + in the sequence. + + Returns + ------- + None + """ + super().__init__(**kwargs) + self.element_rv = element_rv + + def sample( + self, n: int, *args, vectorize: bool = False, **kwargs + ) -> tuple: + """ + Sample an IID random sequence. + + Parameters + ---------- + n : int + Length of the sequence to sample. + + *args : + Additional positional arguments passed + to self.element_rv.sample() + + vectorize: bool + Sample vectorized? If True, use + :meth:`RandomVariable.expand_by()`, + whenever available, and fall back on + :meth:`numpyro.contrib.control_flow.scan`. + If False, always use :meth:`scan()`. + Default False. + + **kwargs: + Additional keyword arguments passed to + self.element_rv.sample(). 
+ + Returns + ------- + tuple[SampledValue] + Whose value is an array of `n` + samples from `self.element_rv` + """ + + if vectorize and hasattr(self.element_rv, "expand_by"): + result, *_ = self.element_rv.expand_by((n,)).sample( + *args, **kwargs + ) + result = result.value + else: + + def transition(_carry, _x): + # numpydoc ignore=GL08 + el, *_ = self.element_rv.sample(*args, **kwargs) + return None, el.value + + _, result = scan( + transition, + xs=None, + init=None, + length=n, + ) + + return ( + SampledValue( + result, + t_start=self.t_start, + t_unit=self.t_unit, + ), + ) + + @staticmethod + def validate(): + """ + Validates input parameters, implementation pending. + """ + super().validate() + return None + + +class StandardNormalSequence(IIDRandomSequence): + """ + Class for a sequence of IID standard Normal + (mean = 0, sd = 1) random variables. + """ + + def __init__( + self, + element_rv_name: str, + **kwargs, + ): + """ + Default constructor + + Parameters + ---------- + name : str + see :class:`IIDRandomSequence`. + element_rv_name: str + Name for the internal element_rv, here a + DistributionalRV encoding a + standard Normal (mean = 0, sd = 1) + distribution.
+ + Returns + ------- + None + """ + super().__init__( + element_rv=DistributionalRV( + name=element_rv_name, distribution=dist.Normal(0, 1) + ), + ) diff --git a/src/pyrenew/process/randomwalk.py b/src/pyrenew/process/randomwalk.py new file mode 100644 index 00000000..a9fa472e --- /dev/null +++ b/src/pyrenew/process/randomwalk.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- +# numpydoc ignore=GL08 + +import numpyro.distributions as dist + +from pyrenew.metaclass import DistributionalRV, RandomVariable +from pyrenew.process.differencedprocess import DifferencedProcess +from pyrenew.process.iidrandomsequence import IIDRandomSequence + + +class RandomWalk(DifferencedProcess): + """ + Class for a Markovian + random walk with an arbitrary + step distribution, implemented + via DifferencedProcess and + IIDRandomSequence + """ + + def __init__( + self, + step_rv: RandomVariable, + **kwargs, + ): + """ + Default constructor + + Parameters + ---------- + name : str + A name for the random variable, used to + name sites within it in :meth:`numpyro.sample()` + calls. + + step_rv : RandomVariable + RandomVariable representing a single step + (difference) in the random walk. + + **kwargs : + Additional keyword arguments passed to the parent + class constructor. + + Returns + ------- + None + """ + super().__init__( + fundamental_process=IIDRandomSequence(element_rv=step_rv), + differencing_order=1, + **kwargs, + ) + + +class StandardNormalRandomWalk(RandomWalk): + """ + A random walk with standard Normal + (mean = 0, standard deviation = 1) + steps, implemented via the base + RandomWalk class. + """ + + def __init__( + self, + step_rv_name: str, + **kwargs, + ): + """ + Default constructor + Parameters + ---------- + step_rv_name : + Name for the DistributionalRV + from which the Normal(0, 1) + steps are sampled. + **kwargs: + Additional keyword arguments passed + to the parent class constructor.
+ Returns + ------- + None + """ + super().__init__( + step_rv=DistributionalRV( + name=step_rv_name, distribution=dist.Normal(0.0, 1.0) + ), + **kwargs, + ) diff --git a/src/pyrenew/process/rtperiodicdiff.py b/src/pyrenew/process/rtperiodicdiffar.py similarity index 73% rename from src/pyrenew/process/rtperiodicdiff.py rename to src/pyrenew/process/rtperiodicdiffar.py index d06e4e31..9186b9ef 100644 --- a/src/pyrenew/process/rtperiodicdiff.py +++ b/src/pyrenew/process/rtperiodicdiffar.py @@ -2,7 +2,6 @@ from typing import NamedTuple import jax.numpy as jnp -from jax.typing import ArrayLike import pyrenew.arrayutils as au from pyrenew.metaclass import ( @@ -10,12 +9,13 @@ SampledValue, _assert_sample_and_rtype, ) -from pyrenew.process.firstdifferencear import FirstDifferenceARProcess +from pyrenew.process import ARProcess, DifferencedProcess -class RtPeriodicDiffProcessSample(NamedTuple): +class RtPeriodicDiffARProcessSample(NamedTuple): """ - A container for holding the output from `process.RtPeriodicDiffProcess()`. + A container for holding the output from + `process.RtPeriodicDiffARProcess()`. Attributes ---------- @@ -26,17 +26,18 @@ class RtPeriodicDiffProcessSample(NamedTuple): rt: SampledValue | None = None def __repr__(self): - return f"RtPeriodicDiffProcessSample(rt={self.rt})" + return f"RtPeriodicDiffARProcessSample(rt={self.rt})" -class RtPeriodicDiffProcess(RandomVariable): +class RtPeriodicDiffARProcess(RandomVariable): r""" - Periodic Rt with autoregressive difference. + Periodic Rt with autoregressive first differences Notes ----- - This class samples a periodic Rt with autoregressive difference. The - mathematical formulation is given by: + This class samples a periodic reproduction number R(t) + by placing an AR(1) process + on the first differences in log[R(t)]. Formally: ..
math:: \log[\mathcal{R}^\mathrm{u}(t_3)] \sim \mathrm{Normal}\left(\log[\mathcal{R}^\mathrm{u}(t_2)] \ @@ -56,9 +57,10 @@ def __init__( log_rt_rv: RandomVariable, autoreg_rv: RandomVariable, periodic_diff_sd_rv: RandomVariable, + ar_process_suffix: str = "_first_diff_ar_process_noise", ) -> None: """ - Default constructor for RtPeriodicDiffProcess class. + Default constructor for RtPeriodicDiffARProcess class. Parameters ---------- @@ -73,6 +75,11 @@ def __init__( Autoregressive parameter. periodic_diff_sd_rv : RandomVariable Standard deviation of the noise. + ar_process_suffix : str + Suffix to append to the :class:`RandomVariable`'s ``name`` + when naming the :class:`RandomVariable` that represents + the underlying AR process noise. + Default "_first_diff_ar_process_noise". Returns ------- @@ -91,6 +98,12 @@ def __init__( self.log_rt_rv = log_rt_rv self.autoreg_rv = autoreg_rv self.periodic_diff_sd_rv = periodic_diff_sd_rv + self.ar_diff = DifferencedProcess( + fundamental_process=ARProcess( + noise_rv_name=f"{name}{ar_process_suffix}" + ), + differencing_order=1, + ) return None @@ -123,36 +136,11 @@ def validate( return None - @staticmethod - def autoreg_process( - dat: ArrayLike, sigma: float - ) -> tuple[ArrayLike, float]: - """ - Scan function for the autoregressive process. - - Parameters - ---------- - dat : ArrayLike - Data array with three elements: log_rt0, log_rt1, and b. - sigma : float - Standard deviation of the noise. - - Returns - ------- - tuple - """ - - log_rt0, log_rt1, b = dat - - next_log_rt = log_rt1 + b * (log_rt1 - log_rt0) + sigma - - return jnp.hstack([log_rt1, next_log_rt, b]), next_log_rt - def sample( self, duration: int, **kwargs, - ) -> RtPeriodicDiffProcessSample: + ) -> RtPeriodicDiffARProcessSample: """ Samples the periodic Rt with autoregressive difference. @@ -166,7 +154,7 @@ def sample( Returns ------- - RtPeriodicDiffProcessSample + RtPeriodicDiffARProcessSample Named tuple with "rt". 
""" @@ -179,17 +167,21 @@ def sample( n_periods = (duration + self.period_size - 1) // self.period_size # Running the process - ar_diff = FirstDifferenceARProcess(self.name, autoreg=b, noise_sd=s_r) - log_rt = ar_diff.sample( - duration=n_periods, - init_val=log_rt_rv[1], - init_rate_of_change=log_rt_rv[1] - log_rt_rv[0], + + log_rt = self.ar_diff( + n=n_periods, + init_vals=jnp.array([log_rt_rv[0]]), + autoreg=b, + noise_sd=s_r, + fundamental_process_init_vals=jnp.array( + [log_rt_rv[1] - log_rt_rv[0]] + ), )[0] - return RtPeriodicDiffProcessSample( + return RtPeriodicDiffARProcessSample( rt=SampledValue( au.repeat_until_n( - data=jnp.exp(log_rt.value.flatten()), + data=jnp.exp(log_rt.value), n_timepoints=duration, offset=self.offset, period_size=self.period_size, @@ -200,9 +192,9 @@ def sample( ) -class RtWeeklyDiffProcess(RtPeriodicDiffProcess): +class RtWeeklyDiffARProcess(RtPeriodicDiffARProcess): """ - Weekly Rt with autoregressive difference. + Weekly Rt with autoregressive first differences. """ def __init__( @@ -214,7 +206,7 @@ def __init__( periodic_diff_sd_rv: RandomVariable, ) -> None: """ - Default constructor for RtWeeklyDiffProcess class. + Default constructor for RtWeeklyDiffARProcess class. 
Parameters ---------- diff --git a/src/pyrenew/process/simplerandomwalk.py b/src/pyrenew/process/simplerandomwalk.py deleted file mode 100644 index 7b358812..00000000 --- a/src/pyrenew/process/simplerandomwalk.py +++ /dev/null @@ -1,104 +0,0 @@ -# -*- coding: utf-8 -*- -# numpydoc ignore=GL08 - -import jax.numpy as jnp -from numpyro.contrib.control_flow import scan - -from pyrenew.metaclass import RandomVariable, SampledValue - - -class SimpleRandomWalkProcess(RandomVariable): - """ - Class for a Markovian - random walk with an a - step distribution - """ - - def __init__( - self, - name: str, - step_rv: RandomVariable, - init_rv: RandomVariable, - t_start: int = None, - t_unit: int = None, - ) -> None: - """ - Default constructor - - Parameters - ---------- - name : str - A name for the random variable, used to - name sites within it in :fun :`numpyro.sample()` - calls. - step_rv : RandomVariable - RandomVariable representing the step distribution. - init_rv : RandomVariable - RandomVariable representing the initial value of - the process - t_start : int - See :class:`RandomVariable` - t_unit : int - See :class:`RandomVariable` - - Returns - ------- - None - """ - self.name = name - self.step_rv = step_rv - self.init_rv = init_rv - self.t_start = t_start - self.t_unit = t_unit - - def sample( - self, - n_steps: int, - **kwargs, - ) -> tuple: - """ - Sample from the random walk. - - Parameters - ---------- - n_steps : int - Length of the walk to sample. - **kwargs : dict, optional - Additional keyword arguments passed through to internal sample() - calls, should there be any. - - Returns - ------- - tuple - With a single array of shape (n_steps,). 
- """ - - init, *_ = self.init_rv(**kwargs) - - def transition(x_prev, _): - # numpydoc ignore=GL08 - diff, *_ = self.step_rv(**kwargs) - x_curr = x_prev + diff.value - return x_curr, x_curr - - _, x = scan( - transition, - init=init.value, - xs=jnp.arange(n_steps - 1), - ) - - return ( - SampledValue( - jnp.hstack([init.value, x.flatten()]), - t_start=self.t_start, - t_unit=self.t_unit, - ), - ) - - @staticmethod - def validate(): - """ - Validates input parameters, implementation pending. - """ - super().validate() - return None diff --git a/src/test/test_ar_process.py b/src/test/test_ar_process.py index 4dfe074d..b1df31af 100755 --- a/src/test/test_ar_process.py +++ b/src/test/test_ar_process.py @@ -2,6 +2,7 @@ import jax.numpy as jnp import numpyro +import pytest from numpy.testing import assert_almost_equal from pyrenew.process import ARProcess @@ -12,19 +13,78 @@ def test_ar_can_be_sampled(): Check that an AR process can be initialized and sampled from """ - ar1 = ARProcess("arprocess", 5, jnp.array([0.95]), jnp.array([0.5])) + ar1 = ARProcess(noise_rv_name="ar1process_noise") with numpyro.handlers.seed(rng_seed=62): - # can sample with and without inits - ar1(duration=3532, inits=jnp.array([50.0])) - ar1(duration=5023) + # can sample + ar1( + n=3532, + init_vals=jnp.array([50.0]), + autoreg=jnp.array([0.95]), + noise_sd=0.5, + ) + + ar3 = ARProcess(noise_rv_name="ar3process_noise") - ar3 = ARProcess( - "arprocess", 5, jnp.array([0.05, 0.025, 0.025]), jnp.array([0.5]) - ) with numpyro.handlers.seed(rng_seed=62): - # can sample with and without inits - ar3(duration=1230) - ar3(duration=52, inits=jnp.array([50.0, 49.9, 48.2])) + # can sample + ar3( + n=1230, + init_vals=jnp.array([50.0, 49.9, 48.2]), + autoreg=jnp.array([0.05, 0.025, 0.025]), + noise_sd=0.5, + ) + ar3( + n=1230, + init_vals=jnp.array([50.0, 49.9, 48.2]), + autoreg=jnp.array([0.05, 0.025, 0.025]), + noise_sd=[0.25], + ) + ar3( + n=1230, + init_vals=jnp.array([50.0, 49.9, 48.2]), + 
autoreg=jnp.array([0.05, 0.025, 0.025]), + noise_sd=jnp.array([0.25]), + ) + + # vector valued noise raises + # error + with pytest.raises(ValueError, match="must be a scalar"): + ar3( + n=1230, + init_vals=jnp.array([50.0, 49.9, 48.2]), + autoreg=jnp.array([0.05, 0.025, 0.025]), + noise_sd=jnp.array([1.0, 2.0]), + ) + with pytest.raises(ValueError, match="must be a scalar"): + ar3( + n=1230, + init_vals=jnp.array([50.0, 49.9, 48.2]), + autoreg=jnp.array([0.05, 0.025, 0.025]), + noise_sd=[1.0, 2.0], + ) + + # bad dimensionality raises error + with pytest.raises(ValueError, match="Array of autoregressive"): + ar3( + n=1230, + init_vals=jnp.array([50.0, 49.9, 48.2]), + autoreg=jnp.array([[0.05, 0.025, 0.025]]), + noise_sd=0.5, + ) + with pytest.raises(ValueError, match="Array of initial"): + ar3( + n=1230, + init_vals=jnp.array([[50.0, 49.9, 48.2]]), + autoreg=jnp.array([0.05, 0.025, 0.025]), + noise_sd=0.5, + ) + with pytest.raises(ValueError, match="same size as the order"): + ar3( + n=1230, + init_vals=jnp.array([50.0, 49.9, 1, 1, 1]), + autoreg=jnp.array([0.05, 0.025, 0.025]), + noise_sd=0.5, + ) def test_ar_samples_correctly_distributed(): @@ -32,13 +92,17 @@ def test_ar_samples_correctly_distributed(): Check that AR processes have correctly- distributed steps. 
""" - ar_mean = 5 noise_sd = jnp.array([0.5]) ar_inits = jnp.array([25.0]) - ar1 = ARProcess("arprocess", ar_mean, jnp.array([0.75]), noise_sd) + ar = ARProcess("arprocess") with numpyro.handlers.seed(rng_seed=62): # check it regresses to mean # when started away from it - long_ts, *_ = ar1(duration=10000, inits=ar_inits) + long_ts, *_ = ar( + n=10000, + init_vals=ar_inits, + autoreg=jnp.array([0.75]), + noise_sd=noise_sd, + ) assert_almost_equal(long_ts.value[0], ar_inits) - assert jnp.abs(long_ts.value[-1] - ar_mean) < 4 * noise_sd + assert jnp.abs(long_ts.value[-1]) < 4 * noise_sd diff --git a/src/test/test_differenced_process.py b/src/test/test_differenced_process.py new file mode 100644 index 00000000..63c28073 --- /dev/null +++ b/src/test/test_differenced_process.py @@ -0,0 +1,271 @@ +""" +Unit tests for the DifferencedProcess class +""" + +import jax +import jax.numpy as jnp +import numpyro +import numpyro.distributions as dist +import pytest +from numpy.testing import assert_array_almost_equal + +from pyrenew.deterministic import DeterministicVariable, NullVariable +from pyrenew.metaclass import DistributionalRV +from pyrenew.process import ( + DifferencedProcess, + IIDRandomSequence, + StandardNormalSequence, +) + + +@pytest.mark.parametrize( + "wrong_type_order", ["test", jnp.array([5.2]), 1.0, NullVariable()] +) +def test_differencing_order_type_validation(wrong_type_order): + """ + Test that passing something other than an + integer as the differencing_order raises + an error via the dedicated assertion function, + that valid types do pass validation, and + that this function is correctly used for + type validation at object instantiation. 
+ """ + err_match = "must be an integer" + with pytest.raises(ValueError, match=err_match): + DifferencedProcess.assert_valid_differencing_order(wrong_type_order) + with pytest.raises(ValueError, match=err_match): + _ = DifferencedProcess( + fundamental_process=None, + differencing_order=wrong_type_order, + ) + DifferencedProcess.assert_valid_differencing_order(1) + _ = DifferencedProcess(fundamental_process=None, differencing_order=1) + + +@pytest.mark.parametrize( + ["wrong_value", "right_value"], [[0, 1], [-5, 5], [-10325235, 300]] +) +def test_differencing_order_value_validation(wrong_value, right_value): + """ + Test that passing an integer that is less than 1 + as the differencing_order raises a ValueError via + the dedicated assertion function, that valid + values do pass, and that the validation function + is correctly used for value validation at + object instantiation. + """ + with pytest.raises(ValueError, match="greater than or equal to 1"): + DifferencedProcess.assert_valid_differencing_order(wrong_value) + _ = DifferencedProcess( + fundamental_process=None, + differencing_order=wrong_value, + ) + + DifferencedProcess.assert_valid_differencing_order(right_value) + _ = DifferencedProcess( + fundamental_process=None, + differencing_order=right_value, + ) + + +@pytest.mark.parametrize( + ["order", "diffs"], + [ + [1, jnp.array([1.0, 2, -3])], + [2, jnp.array([1.0, 2, -3])], + [3, jnp.array([1.0, 2, -3])], + [4, jnp.array([1.0, 2, -3])], + ], +) +def test_integrator_init_validation(order, diffs): + """ + Test that when the integrator is called, + it succeeds if and only if the right number + of initial values have been specified, and raises + the appropriate ValueError otherwise. 
+ """ + inits_short = jnp.ones(order - 1) + inits_correct = jnp.ones(order) + inits_long = jnp.ones(order + 1) + proc = DifferencedProcess( + fundamental_process=None, + differencing_order=order, + ) + with pytest.raises( + ValueError, match="exactly as many initial difference values" + ): + proc.integrate(inits_short, diffs) + with pytest.raises( + ValueError, match="exactly as many initial difference values" + ): + proc.integrate(inits_long, diffs) + proc.integrate(inits_correct, diffs) + + +@pytest.mark.parametrize( + ["order", "n_diffs"], [[1, 250], [2, 40], [3, 10], [4, 10], [5, 5]] +) +def test_integrator_correctness(order, n_diffs): + """ + Test that the scan-based integrate function built in + to DifferencedProcess works equivalently + to a manual implementation. + """ + diffs = jax.random.normal(key=jax.random.key(54), shape=(n_diffs,)) + inits = jax.random.normal(key=jax.random.key(45), shape=(order,)) + result_manual = diffs + for init in jnp.flip(inits): + result_manual = jnp.cumsum(jnp.hstack([init, result_manual])) + + proc = DifferencedProcess( + fundamental_process=None, differencing_order=order + ) + result_proc1 = proc.integrate(inits, diffs) + assert result_proc1.shape == (n_diffs + order,) + assert_array_almost_equal(result_manual, result_proc1, decimal=5) + assert result_proc1[0] == inits[0] + + +@pytest.mark.parametrize( + ["diffs", "inits", "expected_solution"], + [ + [ + jnp.array([0.25, 0.5, 0.5]), + jnp.array([0]), + jnp.array([0, 0.25, 0.75, 1.25]), + ], + [jnp.array([1, 1, 1]), jnp.array([0, 2]), jnp.array([0, 2, 5, 9, 14])], + ], +) +def test_manual_integrator_correctness(diffs, inits, expected_solution): + """ + Test the integrator correctness with manually computed + solutions. 
+ """ + order = inits.size + proc = DifferencedProcess( + fundamental_process=None, differencing_order=order + ) + result = proc.integrate(inits, diffs) + assert_array_almost_equal(result, expected_solution) + + +@pytest.mark.parametrize( + ["fundamental_process", "differencing_order", "init_diff_vals"], + [ + [ + IIDRandomSequence( + DistributionalRV("element_dist", dist.Cauchy(0.02, 0.3)), + ), + 3, + jnp.array([0.25, 0.67, 5]), + ], + [ + StandardNormalSequence("test_stand_norm"), + 5, + jnp.array([0.23, 5.2, 1, 0.2, 3]), + ], + ], +) +def test_differenced_process_sample( + fundamental_process, differencing_order, init_diff_vals +): + """ + Test that differenced processes can be sampled, + that they yield the correct sample shapes, and that + they raise errors if non-feasible sample lengths are + requested. + """ + proc = DifferencedProcess( + differencing_order=differencing_order, + fundamental_process=fundamental_process, + ) + + n_long = differencing_order + 1032 + n_long_alt = differencing_order + 235 + n_one_diff = differencing_order + 1 + n_no_diffs = differencing_order + n_no_diffs_alt = differencing_order - 1 + n_fail = -1 + n_fail_alt = 0 + with numpyro.handlers.seed(rng_seed=6723): + samp, *_ = proc.sample(n=n_long, init_vals=init_diff_vals) + samp_alt, *_ = proc.sample(n=n_long_alt, init_vals=init_diff_vals) + samp_one_diff, *_ = proc.sample(n=n_one_diff, init_vals=init_diff_vals) + samp_no_diffs, *_ = proc.sample(n=n_no_diffs, init_vals=init_diff_vals) + samp_no_diffs_alt, *_ = proc.sample( + n=n_no_diffs_alt, init_vals=init_diff_vals + ) + assert samp.value.shape == (n_long,) + assert samp_alt.value.shape == (n_long_alt,) + assert samp_one_diff.value.shape == (n_one_diff,) + assert samp_no_diffs.value.shape == (n_no_diffs,) + assert samp_no_diffs_alt.value.shape == (n_no_diffs_alt,) + + with numpyro.handlers.seed(rng_seed=7834): + with pytest.raises(ValueError, match="must be positive"): + proc.sample(n=n_fail, init_vals=init_diff_vals) + with 
pytest.raises(ValueError, match="must be positive"): + proc.sample(n=n_fail_alt, init_vals=init_diff_vals) + with pytest.raises(ValueError, match="1-dimensional"): + proc.sample(n=n_long, init_vals=jnp.atleast_2d(init_diff_vals)) + + +@pytest.mark.parametrize( + ["fundamental_process", "inits", "n", "expected_solution"], + [ + [ + IIDRandomSequence( + element_rv=DeterministicVariable("zero", jnp.array(0.0)), + ), + jnp.array([0.0, 0, 0, 0, 0]), + 3, + jnp.array([0.0, 0.0, 0.0]), + ], + [ + IIDRandomSequence( + element_rv=DeterministicVariable("zero", jnp.array(1.0)), + ), + jnp.array([0]), + 5, + jnp.array([0.0, 1.0, 2.0, 3.0, 4.0]), + ], + [ + IIDRandomSequence( + element_rv=DeterministicVariable("zero", jnp.array(1.0)), + ), + jnp.array([0, 1]), + 7, + jnp.array([0.0, 1.0, 3.0, 6.0, 10.0, 15.0, 21.0]), + ], + [ + IIDRandomSequence( + element_rv=DeterministicVariable("zero", jnp.array(1.0)), + ), + jnp.array([0, 1]), + 1, + jnp.array([0.0]), + ], + [ + IIDRandomSequence( + element_rv=DeterministicVariable("zero", jnp.array(1.0)), + ), + jnp.array([0, 1]), + 2, + jnp.array([0.0, 1.0]), + ], + ], +) +def test_manual_difference_process_sample( + fundamental_process, inits, n, expected_solution +): + """ + Test the correctness of DifferencedProcess.sample() + with manually computed solutions + """ + proc = DifferencedProcess( + differencing_order=len(inits), + fundamental_process=fundamental_process, + ) + result, *_ = proc.sample(n=n, init_vals=inits) + assert_array_almost_equal(result.value, expected_solution) diff --git a/src/test/test_distributional_rv.py b/src/test/test_distributional_rv.py index 6d0f2ff5..0a0b4d2c 100644 --- a/src/test/test_distributional_rv.py +++ b/src/test/test_distributional_rv.py @@ -6,6 +6,7 @@ import numpyro.distributions as dist import pytest from numpy.testing import assert_array_equal +from numpyro.distributions import ExpandedDistribution from pyrenew.metaclass import ( DistributionalRV, @@ -76,6 +77,53 @@ def 
test_factory_triage(valid_static_dist_arg, valid_dynamic_dist_arg): assert isinstance(dynamic, DynamicDistributionalRV) +@pytest.mark.parametrize( + ["dist", "params", "expand_by_shape"], + [ + [dist.Normal, {"loc": 0.0, "scale": 0.5}, (5,)], + [dist.Poisson, {"rate": 0.35265}, (20, 25)], + [ + dist.Cauchy, + { + "loc": jnp.array([1.0, 5.0, -0.25]), + "scale": jnp.array([0.02, 0.15, 2]), + }, + (10, 10, 3), + ], + ], +) +def test_expand_by(dist, params, expand_by_shape): + """ + Test the expand_by method for static + distributional RVs. + """ + static = DistributionalRV(name="static", distribution=dist(**params)) + dynamic = DistributionalRV(name="dynamic", distribution=dist) + expanded_static = static.expand_by(expand_by_shape) + expanded_dynamic = dynamic.expand_by(expand_by_shape) + + assert isinstance(expanded_dynamic, DynamicDistributionalRV) + assert dynamic.expand_by_shape is None + assert isinstance(expanded_dynamic.expand_by_shape, tuple) + assert expanded_dynamic.expand_by_shape == expand_by_shape + assert dynamic.reparam_dict == expanded_dynamic.reparam_dict + assert ( + dynamic.distribution_constructor + == expanded_dynamic.distribution_constructor + ) + + assert isinstance(expanded_static, StaticDistributionalRV) + assert isinstance(expanded_static.distribution, ExpandedDistribution) + assert expanded_static.distribution.batch_shape == ( + expand_by_shape + static.distribution.batch_shape + ) + + with pytest.raises(ValueError): + dynamic.expand_by("not a tuple") + with pytest.raises(ValueError): + static.expand_by("not a tuple") + + @pytest.mark.parametrize( ["dist", "params"], [ diff --git a/src/test/test_first_difference_ar.py b/src/test/test_first_difference_ar.py deleted file mode 100755 index 7c3692a9..00000000 --- a/src/test/test_first_difference_ar.py +++ /dev/null @@ -1,31 +0,0 @@ -# -*- coding: utf-8 -*- -# numpydoc ignore=GL08 - -import jax.numpy as jnp -import numpyro - -from pyrenew.process import FirstDifferenceARProcess - - -def 
test_fd_ar_can_be_sampled(): - """ - Check that stochastic process - with AR(1) first differences - can be initialized and sampled - from - """ - ar_fd = FirstDifferenceARProcess("trend_rw", 0.5, 0.5) - - with numpyro.handlers.seed(rng_seed=62): - # can sample with and without inits - # for the rate of change - ans0 = ar_fd(duration=3532, init_val=jnp.array([50.0])) - ans1 = ar_fd( - duration=3532, - init_val=jnp.array([50.0]), - init_rate_of_change=jnp.array([0.25]), - ) - - # Checking proper shape - assert ans0[0].value.shape == (3532,) - assert ans1[0].value.shape == (3532,) diff --git a/src/test/test_forecast.py b/src/test/test_forecast.py index 1854aa24..beef0273 100644 --- a/src/test/test_forecast.py +++ b/src/test/test_forecast.py @@ -1,6 +1,6 @@ # numpydoc ignore=GL08 -from test.utils import simple_rt +from test.utils import SimpleRt import jax.numpy as jnp import jax.random as jr @@ -20,7 +20,10 @@ def test_forecast(): - """Check that forecasts are the right length and match the posterior up until forecast begins.""" + """ + Check that forecasts are the right length and + match the posterior up until forecast begins. 
+ """ pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25]) gen_int = DeterministicPMF(name="gen_int", value=pmf_array) I0 = InfectionInitializationProcess( @@ -31,7 +34,7 @@ def test_forecast(): ) latent_infections = Infections() observed_infections = PoissonObservation(name="poisson_rv") - rt = simple_rt() + rt = SimpleRt() model = RtInfectionsRenewalModel( I0_rv=I0, @@ -63,8 +66,8 @@ def test_forecast(): == n_datapoints + n_forecast_points ) - # Check the first elements of the posterior predictive Rt are the same as the - # posterior Rt + # Check the first elements of the posterior predictive Rt + # are the same as the posterior Rt assert_array_equal( model.mcmc.get_samples()["Rt"][0], posterior_predictive_samples["Rt"][0][ diff --git a/src/test/test_iid_random_sequence.py b/src/test/test_iid_random_sequence.py new file mode 100755 index 00000000..eb6d943c --- /dev/null +++ b/src/test/test_iid_random_sequence.py @@ -0,0 +1,80 @@ +# numpydoc ignore=GL08 + +import jax.numpy as jnp +import numpyro +import numpyro.distributions as dist +import pytest +from scipy.stats import kstest + +from pyrenew.metaclass import ( + DistributionalRV, + SampledValue, + StaticDistributionalRV, +) +from pyrenew.process import IIDRandomSequence, StandardNormalSequence + + +@pytest.mark.parametrize( + ["distribution", "n"], + [ + [dist.Normal(0, 1), 1000], + [dist.Cauchy(2, 325.0), 13532], + [dist.Normal(jnp.array([2.0, 3.0, -5.235]), 0.25), 622], + ], +) +def test_iidrandomsequence_with_dist_rv(distribution, n): + """ + Check that an IIDRandomSequence can be + initialized and sampled from when the element_rv is + a distributional RV, including with array-valued + distributions + """ + element_rv = DistributionalRV("el_rv", distribution=distribution) + rseq = IIDRandomSequence(element_rv=element_rv) + if distribution.batch_shape == () or distribution.batch_shape == (1,): + expected_shape = (n,) + else: + expected_shape = tuple([n] + [x for x in distribution.batch_shape]) + + with 
numpyro.handlers.seed(rng_seed=62): + ans_vec, *_ = rseq.sample(n=n, vectorize=True) + ans_serial, *_ = rseq.sample(n=n, vectorize=False) + + # check that samples are the right type + for ans in [ans_serial, ans_vec]: + assert isinstance(ans, SampledValue) + # check that the samples are of the right shape + assert ans.value.shape == expected_shape + + # vectorized and unvectorized sampling should + # not give the same answer + # but they should give similar distributions + assert all(ans_serial.value.flatten() != ans_vec.value.flatten()) + + if expected_shape == (n,): + kstest_out = kstest(ans_serial.value, ans_vec.value) + assert kstest_out.pvalue > 0.01 + + +def test_standard_normal_sequence(): + """ + Test the StandardNormalSequence RandomVariable + class. + """ + norm_seq = StandardNormalSequence("test_norm_elements") + + # should be implemented with a DistributionalRV + # that is a standard normal + assert isinstance(norm_seq.element_rv, StaticDistributionalRV) + assert isinstance(norm_seq.element_rv.distribution, dist.Normal) + assert norm_seq.element_rv.distribution.loc == 0.0 + assert norm_seq.element_rv.distribution.scale == 1.0 + + # should be sampleable + with numpyro.handlers.seed(rng_seed=67): + ans, *_ = norm_seq.sample(n=50000) + + assert isinstance(ans, SampledValue) + # samples should be approximately standard normal + kstest_out = kstest(ans.value, "norm", (0, 1)) + assert kstest_out.pvalue > 0.01 diff --git a/src/test/test_infection_seeding_method.py b/src/test/test_infection_initialization_method.py similarity index 96% rename from src/test/test_infection_seeding_method.py rename to src/test/test_infection_initialization_method.py index afb4a0ad..4ca77510 100644 --- a/src/test/test_infection_seeding_method.py +++ b/src/test/test_infection_initialization_method.py @@ -85,10 +85,14 @@ def test_initialize_infections_zero_pad(): infections = InitializeInfectionsZeroPad( n_timepoints ).initialize_infections(I_pre_init) - testing.assert_array_equal( - 
infections, np.pad(I_pre_init, (n_timepoints - I_pre_init.size, 0)) + + manual_infections = np.pad( + np.atleast_1d(I_pre_init), + (n_timepoints - np.array(I_pre_init).size, 0), ) + testing.assert_array_equal(infections, manual_infections) + I_pre_init_RV_2 = DeterministicVariable( name="I_pre_init_RV", value=np.array([10.0, 10.0]) ) diff --git a/src/test/test_infection_seeding_process.py b/src/test/test_infection_initialization_process.py similarity index 100% rename from src/test/test_infection_seeding_process.py rename to src/test/test_infection_initialization_process.py diff --git a/src/test/test_latent_admissions.py b/src/test/test_latent_admissions.py index 92acec2b..564f11f5 100644 --- a/src/test/test_latent_admissions.py +++ b/src/test/test_latent_admissions.py @@ -1,20 +1,16 @@ # -*- coding: utf-8 -*- # numpydoc ignore=GL08 +from test.utils import SimpleRt + import jax.numpy as jnp import numpy.testing as testing import numpyro import numpyro.distributions as dist -from pyrenew import transformation as t from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.latent import HospitalAdmissions, Infections -from pyrenew.metaclass import ( - DistributionalRV, - SampledValue, - TransformedRandomVariable, -) -from pyrenew.process import SimpleRandomWalkProcess +from pyrenew.metaclass import DistributionalRV, SampledValue def test_admissions_sample(): @@ -25,22 +21,10 @@ def test_admissions_sample(): # Generating Rt and Infections to compute the hospital admissions - rt = TransformedRandomVariable( - name="Rt_rv", - base_rv=SimpleRandomWalkProcess( - name="log_rt", - step_rv=DistributionalRV( - name="rw_step_rv", distribution=dist.Normal(0, 0.025) - ), - init_rv=DistributionalRV( - name="init_log_rt", distribution=dist.Normal(0, 0.2) - ), - ), - transforms=t.ExpTransform(), - ) + rt = SimpleRt() with numpyro.handlers.seed(rng_seed=223): - sim_rt = rt(n_steps=30)[0].value + sim_rt = rt(n=30)[0].value gen_int = jnp.array([0.5, 0.1, 
0.1, 0.2, 0.1]) i0 = 10 * jnp.ones_like(gen_int) @@ -91,8 +75,6 @@ def test_admissions_sample(): sim_hosp_1.latent_hospital_admissions.value, inf_sampled1[0].value, ) - - # Testing the offset in the observed data inf_hosp2 = jnp.ones(30) inf_hosp2 = DeterministicPMF("i2h", inf_hosp2 / sum(inf_hosp2)) @@ -143,6 +125,3 @@ def test_admissions_sample(): sim_hosp_2a[2 : (sim_hosp_2b.size - 2)], sim_hosp_2b[: (sim_hosp_2b.size - 4)], ) - - -test_admissions_sample() diff --git a/src/test/test_latent_infections.py b/src/test/test_latent_infections.py index 7fb752ad..7fc500f5 100755 --- a/src/test/test_latent_infections.py +++ b/src/test/test_latent_infections.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # numpydoc ignore=GL08 -from test.utils import simple_rt +from test.utils import SimpleRt import jax.numpy as jnp import numpy.testing as testing @@ -17,10 +17,10 @@ def test_infections_as_deterministic(): the same seed is used. """ - rt = simple_rt() + rt = SimpleRt() with numpyro.handlers.seed(rng_seed=223): - sim_rt, *_ = rt(n_steps=30) + sim_rt, *_ = rt(n=30) gen_int = jnp.array([0.25, 0.25, 0.25, 0.25]) diff --git a/src/test/test_model_basic_renewal.py b/src/test/test_model_basic_renewal.py index 8235ca7f..ffe09cd4 100644 --- a/src/test/test_model_basic_renewal.py +++ b/src/test/test_model_basic_renewal.py @@ -2,7 +2,7 @@ # numpydoc ignore=GL08 -from test.utils import simple_rt +from test.utils import SimpleRt import jax.numpy as jnp import jax.random as jr @@ -34,16 +34,21 @@ def test_model_basicrenewal_no_timepoints_or_observations(): name="gen_int", value=jnp.array([0.25, 0.25, 0.25, 0.25]) ) - I0 = DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)) + I0_init_rv = InfectionInitializationProcess( + "I0_initialization", + DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), + t_unit=1, + ) latent_infections = Infections() observed_infections = PoissonObservation("poisson_rv") - rt = 
simple_rt() + rt = SimpleRt() model1 = RtInfectionsRenewalModel( - I0_rv=I0, + I0_rv=I0_init_rv, gen_int_rv=gen_int, latent_infections_rv=latent_infections, infection_obs_process_rv=observed_infections, @@ -65,16 +70,21 @@ def test_model_basicrenewal_both_timepoints_and_observations(): value=jnp.array([0.25, 0.25, 0.25, 0.25]), ) - I0 = DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)) + I0_init_rv = InfectionInitializationProcess( + "I0_initialization", + DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), + t_unit=1, + ) latent_infections = Infections() observed_infections = PoissonObservation("possion_rv") - rt = simple_rt() + rt = SimpleRt() model1 = RtInfectionsRenewalModel( - I0_rv=I0, + I0_rv=I0_init_rv, gen_int_rv=gen_int, latent_infections_rv=latent_infections, infection_obs_process_rv=observed_infections, @@ -101,9 +111,9 @@ def test_model_basicrenewal_no_obs_model(): ) with pytest.raises(ValueError): - I0 = DistributionalRV(name="I0", distribution=1) + _ = DistributionalRV(name="I0", distribution=1) - I0 = InfectionInitializationProcess( + I0_init_rv = InfectionInitializationProcess( "I0_initialization", DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), @@ -112,11 +122,11 @@ def test_model_basicrenewal_no_obs_model(): latent_infections = Infections() - rt = simple_rt() + rt = SimpleRt() model0 = RtInfectionsRenewalModel( gen_int_rv=gen_int, - I0_rv=I0, + I0_rv=I0_init_rv, latent_infections_rv=latent_infections, Rt_process_rv=rt, # Explicitly use None, this should call the NullObservation @@ -174,7 +184,7 @@ def test_model_basicrenewal_with_obs_model(): name="gen_int", value=jnp.array([0.25, 0.25, 0.25, 0.25]) ) - I0 = InfectionInitializationProcess( + I0_init_rv = InfectionInitializationProcess( "I0_initialization", DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), 
InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), @@ -185,10 +195,10 @@ def test_model_basicrenewal_with_obs_model(): observed_infections = PoissonObservation("poisson_rv") - rt = simple_rt() + rt = SimpleRt() model1 = RtInfectionsRenewalModel( - I0_rv=I0, + I0_rv=I0_init_rv, gen_int_rv=gen_int, latent_infections_rv=latent_infections, infection_obs_process_rv=observed_infections, @@ -199,6 +209,11 @@ def test_model_basicrenewal_with_obs_model(): with numpyro.handlers.seed(rng_seed=223): model1_samp = model1.sample(n_datapoints=30) + print(model1_samp) + print(model1_samp.Rt.value.size) + print(model1_samp.latent_infections.value.size) + print(model1_samp.observed_infections.value.size) + model1.run( num_warmup=500, num_samples=500, @@ -223,7 +238,7 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08 name="gen_int", value=jnp.array([0.25, 0.25, 0.25, 0.25]) ) - I0 = InfectionInitializationProcess( + I0_init_rv = InfectionInitializationProcess( "I0_initialization", DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), @@ -234,18 +249,18 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08 observed_infections = PoissonObservation("poisson_rv") - rt = simple_rt() + rt = SimpleRt() model1 = RtInfectionsRenewalModel( - I0_rv=I0, + I0_rv=I0_init_rv, gen_int_rv=gen_int, latent_infections_rv=latent_infections, infection_obs_process_rv=observed_infections, Rt_process_rv=rt, ) - # Sampling and fitting model 1 (with obs infections) pad_size = 5 + with numpyro.handlers.seed(rng_seed=223): model1_samp = model1.sample(n_datapoints=30, padding=pad_size) @@ -254,7 +269,7 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08 num_samples=500, rng_key=jr.key(22), data_observed_infections=model1_samp.observed_infections.value, - padding=5, + padding=pad_size, ) inf = model1.spread_draws(["all_latent_infections"]) diff --git 
a/src/test/test_model_hosp_admissions.py b/src/test/test_model_hosp_admissions.py index fc2e4f57..4d20815f 100644 --- a/src/test/test_model_hosp_admissions.py +++ b/src/test/test_model_hosp_admissions.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- # numpydoc ignore=GL08 -from test.utils import simple_rt + +from test.utils import SimpleRt import jax.numpy as jnp import jax.random as jr @@ -67,7 +68,7 @@ def test_model_hosp_no_timepoints_or_observations(): I0 = DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)) latent_infections = Infections() - Rt_process = simple_rt() + Rt_process = SimpleRt() observed_admissions = PoissonObservation("poisson_rv") @@ -132,7 +133,7 @@ def test_model_hosp_both_timepoints_and_observations(): I0 = DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)) latent_infections = Infections() - Rt_process = simple_rt() + Rt_process = SimpleRt() observed_admissions = PoissonObservation("poisson_rv") @@ -205,7 +206,7 @@ def test_model_hosp_no_obs_model(): ) latent_infections = Infections() - Rt_process = simple_rt() + Rt_process = SimpleRt() inf_hosp = DeterministicPMF( name="inf_hosp", @@ -315,7 +316,7 @@ def test_model_hosp_with_obs_model(): ) latent_infections = Infections() - Rt_process = simple_rt() + Rt_process = SimpleRt() observed_admissions = PoissonObservation("poisson_rv") inf_hosp = DeterministicPMF( @@ -402,7 +403,7 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): ) latent_infections = Infections() - Rt_process = simple_rt() + Rt_process = SimpleRt() observed_admissions = PoissonObservation("poisson_rv") inf_hosp = DeterministicPMF( @@ -495,7 +496,7 @@ def test_model_hosp_with_obs_model_weekday_phosp(): ) latent_infections = Infections() - Rt_process = simple_rt() + Rt_process = SimpleRt() observed_admissions = PoissonObservation("poisson_rv") diff --git a/src/test/test_predictive.py b/src/test/test_predictive.py index f655ce96..5c76b98b 100644 --- a/src/test/test_predictive.py +++ 
b/src/test/test_predictive.py @@ -1,10 +1,11 @@ # -*- coding: utf-8 -*- """ -Ensures that posterior predictive samples are not generated when no posterior samples are available. +Ensures that posterior predictive samples are not generated +when no posterior samples are available. """ -from test.utils import simple_rt +from test.utils import SimpleRt import jax.numpy as jnp import numpyro.distributions as dist @@ -30,7 +31,7 @@ ) latent_infections = Infections() observed_infections = PoissonObservation("poisson_rv") -rt = simple_rt() +rt = SimpleRt() model = RtInfectionsRenewalModel( I0_rv=I0, diff --git a/src/test/test_random_key.py b/src/test/test_random_key.py index e7fe03f7..6d6cfd43 100644 --- a/src/test/test_random_key.py +++ b/src/test/test_random_key.py @@ -5,7 +5,7 @@ with different random keys behave appropriately. """ -from test.utils import simple_rt +from test.utils import SimpleRt import jax.numpy as jnp import jax.random as jr @@ -35,7 +35,7 @@ def create_test_model(): # numpydoc ignore=GL08 ) latent_infections = Infections() observed_infections = PoissonObservation("poisson_rv") - rt = simple_rt() + rt = SimpleRt() model = RtInfectionsRenewalModel( I0_rv=I0, gen_int_rv=gen_int, diff --git a/src/test/test_random_walk.py b/src/test/test_random_walk.py index 77583704..d7e2cabd 100755 --- a/src/test/test_random_walk.py +++ b/src/test/test_random_walk.py @@ -3,97 +3,113 @@ import jax.numpy as jnp import numpyro import numpyro.distributions as dist -from numpy.testing import assert_almost_equal +import pytest +from numpy.testing import assert_almost_equal, assert_array_almost_equal from pyrenew.deterministic import DeterministicVariable -from pyrenew.metaclass import DistributionalRV -from pyrenew.process import SimpleRandomWalkProcess - - -def test_rw_can_be_sampled(): +from pyrenew.metaclass import DistributionalRV, RandomVariable +from pyrenew.process import RandomWalk, StandardNormalRandomWalk + + +@pytest.mark.parametrize( + ["element_rv", 
"init_value"], + [ + [DistributionalRV("test_normal", dist.Normal(0.5, 1)), 50.0], + [DistributionalRV("test_cauchy", dist.Cauchy(0.25, 0.25)), -3], + ["test standard normal", jnp.array(3)], + ], +) +def test_rw_can_be_sampled(element_rv, init_value): """ - Check that a simple random walk + Check that a RandomWalk and a StandardNormalRandomWalk can be initialized and sampled from """ - init_rv_rand = DistributionalRV( - name="init_rv_rand", - distribution=dist.Normal(1, 0.5), - ) - init_rv_fixed = DeterministicVariable(name="init_rv_fixed", value=50.0) - - step_rv = DistributionalRV( - name="rw_step", - distribution=dist.Normal(0, 1), - ) + init_rv = DeterministicVariable(name="init_rv_fixed", value=init_value) - rw_init_rand = SimpleRandomWalkProcess( - "rw_rand_init", step_rv=step_rv, init_rv=init_rv_rand - ) - - rw_init_fixed = SimpleRandomWalkProcess( - "rw_fixed_init", step_rv=step_rv, init_rv=init_rv_fixed - ) + if isinstance(element_rv, RandomVariable): + rw = RandomWalk(element_rv) + elif element_rv == "test standard normal": + rw = StandardNormalRandomWalk("std_normal_step") + else: + raise ValueError("Unexpected element_rv") with numpyro.handlers.seed(rng_seed=62): # can sample with a fixed init # and with a random init - ans_rand = rw_init_rand(n_steps=3532) - ans_fixed = rw_init_fixed(n_steps=5023) - + init_vals = init_rv()[0].value + ans_long = rw(n=5023, init_vals=init_vals) + ans_short = rw(n=1, init_vals=init_vals) + + # Providing more than one init val should + # raise an error. 
+ with pytest.raises(ValueError, match="differencing order"): + rw(n=523, init_vals=jnp.hstack([init_vals, 0.25])) # check that the samples are of the right shape - assert ans_rand[0].value.shape == (3532,) - assert ans_fixed[0].value.shape == (5023,) - - # check that fixing inits works - assert_almost_equal(ans_fixed[0].value[0], init_rv_fixed.value) - assert ans_rand[0].value[0] != init_rv_fixed.value + assert ans_long[0].value.shape == (5023,) + assert ans_short[0].value.shape == (1,) + # check that the first n_inits samples are the inits + n_inits = jnp.atleast_1d(init_vals).size + assert_array_almost_equal( + ans_long[0].value[0:n_inits], jnp.atleast_1d(init_vals) + ) + assert_array_almost_equal( + ans_short[0].value, jnp.atleast_1d(init_vals)[:1] + ) -def test_rw_samples_correctly_distributed(): +@pytest.mark.parametrize( + ["step_mean", "step_sd"], + [ + [0, 1], + [0, 0.25], + [2.253, 0.025], + [-3.2521, 1], + [1052, 3], + [1e-6, 0.02], + ], +) +def test_normal_rw_samples_correctly_distributed(step_mean, step_sd): """ - Check that a simple random walk has steps - distributed according to the target distribution + Check that Normal random walks have steps + distributed according to the target Normal distributions, + including the StandardNormalRandomWalk. 
""" n_samples = 10000 - for step_mean, step_sd in zip( - [0, 2.253, -3.2521, 1052, 1e-6], [1, 0.025, 3, 1, 0.02] - ): - rw_init_val = 532.0 - rw_normal = SimpleRandomWalkProcess( - name="rw_normal_test", + rw_init_val = jnp.array([532.0]) + if step_mean == 0 and step_sd == 1: + rw_normal = StandardNormalRandomWalk("test standard normal") + else: + rw_normal = RandomWalk( step_rv=DistributionalRV( - name="rw_normal_dist", + name="rw_step_dist", distribution=dist.Normal(loc=step_mean, scale=step_sd), ), - init_rv=DeterministicVariable( - name="init_rv_fixed", value=rw_init_val - ), ) - with numpyro.handlers.seed(rng_seed=62): - samples, *_ = rw_normal(n_steps=n_samples) - samples = samples.value - - # Checking the shape - assert samples.shape == (n_samples,) - - # diffs should not be greater than - # 5 sigma - diffs = jnp.diff(samples) - assert jnp.all(jnp.abs(diffs - step_mean) < 5 * step_sd) - - # sample mean of diffs should be - # approximately equal to the - # step mean, according to - # the Law of Large Numbers - deviation_threshold = 4 * jnp.sqrt((step_sd**2) / n_samples) - assert jnp.abs(jnp.mean(diffs) - step_mean) < deviation_threshold - - # sample sd of diffs - # should be approximately equal - # to the step sd - assert jnp.abs(jnp.log(jnp.std(diffs) / step_sd)) < jnp.log(1.1) - - # first value should be the init value - assert_almost_equal(samples[0], rw_init_val) + with numpyro.handlers.seed(rng_seed=62): + samples, *_ = rw_normal(n=n_samples, init_vals=rw_init_val) + samples = samples.value + + # Checking the shape + assert samples.shape == (n_samples,) + + # diffs should not be greater than + # 5 sigma + diffs = jnp.diff(samples) + assert jnp.all(jnp.abs(diffs - step_mean) < 5 * step_sd) + + # sample mean of diffs should be + # approximately equal to the + # step mean, according to + # the Law of Large Numbers + deviation_threshold = 4 * jnp.sqrt((step_sd**2) / n_samples) + assert jnp.abs(jnp.mean(diffs) - step_mean) < deviation_threshold + + # sample 
sd of diffs + # should be approximately equal + # to the step sd + assert jnp.abs(jnp.log(jnp.std(diffs) / step_sd)) < jnp.log(1.1) + + # first value should be the init value + assert_almost_equal(samples[0], rw_init_val) diff --git a/src/test/test_rtperiodicdiff.py b/src/test/test_rtperiodicdiff.py index 9927be7c..8d1ac28a 100644 --- a/src/test/test_rtperiodicdiff.py +++ b/src/test/test_rtperiodicdiff.py @@ -1,50 +1,14 @@ """ -Test the rtweeklydiff module +Test the RtPeriodicDiffARProcess module """ import jax.numpy as jnp -import numpy as np import numpyro -from jax import lax -from jax.typing import ArrayLike -from numpy.testing import assert_array_almost_equal, assert_array_equal +import pytest +from numpy.testing import assert_array_equal from pyrenew.deterministic import DeterministicVariable -from pyrenew.process import RtWeeklyDiffProcess - - -def _manual_rt_weekly_diff( - log_seed: ArrayLike, sd: ArrayLike, b: ArrayLike -) -> ArrayLike: - """ - 'Manually' (for-loop) reconstruct the scanner process - - Parameters - ---------- - log_seed : ArrayLike - The initial log seed. - sd : ArrayLike - The noise. - b : ArrayLike - The autoregressive parameter. - - Returns - ------- - ArrayLike - The reconstructed RtWeeklyDiffProcess process. 
- """ - - log_ans = np.zeros(sd.size + 2) - log_ans[0] = log_seed[0] - log_ans[1] = log_seed[1] - for i in range(2, sd.size + 2): - log_ans[i] = ( - log_ans[i - 1] - + b[0] * (log_ans[i - 1] - log_ans[i - 2]) - + sd[i - 2] - ) - - return log_ans[2:] +from pyrenew.process import RtWeeklyDiffARProcess def test_rtweeklydiff() -> None: @@ -65,7 +29,7 @@ def test_rtweeklydiff() -> None: } duration = 30 - rtwd = RtWeeklyDiffProcess(**params) + rtwd = RtWeeklyDiffARProcess(**params) with numpyro.handlers.seed(rng_seed=223): rt = rtwd(duration=duration).rt.value @@ -80,7 +44,7 @@ def test_rtweeklydiff() -> None: # Checking start off a different day of the week params["offset"] = 5 - rtwd = RtWeeklyDiffProcess(**params) + rtwd = RtWeeklyDiffARProcess(**params) with numpyro.handlers.seed(rng_seed=223): rt2 = rtwd(duration=duration).rt.value @@ -114,7 +78,7 @@ def test_rtweeklydiff_no_autoregressive() -> None: ), } - rtwd = RtWeeklyDiffProcess(**params) + rtwd = RtWeeklyDiffARProcess(**params) duration = 1000 @@ -136,46 +100,10 @@ def test_rtweeklydiff_no_autoregressive() -> None: return None -def test_rtweeklydiff_manual_reconstruction() -> None: - """Checks that the 'manual' reconstruction is correct""" - - params = { - "name": "test", - "offset": 0, - "log_rt_rv": DeterministicVariable( - name="log_rt", - value=jnp.array([0.1, 0.2]), - ), - "autoreg_rv": DeterministicVariable( - name="autoreg_rv", value=jnp.array([0.7]) - ), - "periodic_diff_sd_rv": DeterministicVariable( - name="periodic_diff_sd_rv", - value=jnp.array([0.1]), - ), - } - - rtwd = RtWeeklyDiffProcess(**params) - b = jnp.array([0.7]) - - noise = np.random.normal(0, 0.1, 10) - - _, ans0 = lax.scan( - f=rtwd.autoreg_process, - init=np.hstack([params["log_rt_rv"]()[0].value, b]), - xs=noise, - ) - - ans1 = _manual_rt_weekly_diff( - log_seed=params["log_rt_rv"]()[0].value, sd=noise, b=b - ) - - assert_array_almost_equal(ans0, ans1) - - return None - - -def test_rtperiodicdiff_smallsample(): 
+@pytest.mark.parametrize( + "inits", [jnp.array([0.1, 0.2]), jnp.array([0.5, 0.7])] +) +def test_rtperiodicdiff_smallsample(inits): """Checks basic functionality of the process with a small sample size.""" params = { @@ -183,7 +111,7 @@ def test_rtperiodicdiff_smallsample(): "offset": 0, "log_rt_rv": DeterministicVariable( name="log_rt", - value=jnp.array([0.1, 0.2]), + value=inits, ), "autoreg_rv": DeterministicVariable( name="autoreg_rv", value=jnp.array([0.7]) @@ -194,7 +122,7 @@ def test_rtperiodicdiff_smallsample(): ), } - rtwd = RtWeeklyDiffProcess(**params) + rtwd = RtWeeklyDiffARProcess(**params) with numpyro.handlers.seed(rng_seed=223): rt = rtwd(duration=6).rt.value @@ -203,4 +131,5 @@ def test_rtperiodicdiff_smallsample(): assert rt.shape == (6,) # Check that all values in rt are the same - assert jnp.all(rt == rt[0]) + # as the first initial value + assert jnp.all(rt == jnp.exp(inits[0])) diff --git a/src/test/utils.py b/src/test/utils.py index 01f4e0de..be551dfe 100644 --- a/src/test/utils.py +++ b/src/test/utils.py @@ -7,37 +7,64 @@ import numpyro.distributions as dist import pyrenew.transformation as t -from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable -from pyrenew.process import SimpleRandomWalkProcess +from pyrenew.metaclass import ( + DistributionalRV, + RandomVariable, + SampledValue, + TransformedRandomVariable, +) +from pyrenew.process import RandomWalk -def simple_rt(arg_name: str = "Rt_rv"): +class SimpleRt(RandomVariable): """ - Helper function to create a default Rt + Helper class to create a default Rt RandomVariable for testing. - - Parameters - ----------- - arg_name : str - Name assigned to the randonvariable. 
- If None, then defaults to "Rt_rv" - - Returns - ------- - TransformedRandomVariable : - A log-scale random walk with fixed - init value and step size priors """ - return TransformedRandomVariable( - arg_name, - base_rv=SimpleRandomWalkProcess( - name="log_rt", - step_rv=DistributionalRV( - name="rw_step_rv", distribution=dist.Normal(0, 0.025) - ), - init_rv=DistributionalRV( - name="init_log_rt", distribution=dist.Normal(0, 0.2) + + def __init__(self, name: str = "Rt_rv"): + """ + Default constructor + + Parameters + ----------- + name : str + Name assigned to the RandomVariable. + If None, then defaults to "Rt_rv" + + Returns + ------- + None + """ + self.name = name + self.rt_rv_ = TransformedRandomVariable( + name=f"{name}_log_rt_random_walk", + base_rv=RandomWalk( + name="log_rt", + step_rv=DistributionalRV( + name="rw_step_rv", distribution=dist.Normal(0, 0.025) + ), ), - ), - transforms=t.ExpTransform(), - ) + transforms=t.ExpTransform(), + ) + self.rt_init_rv_ = DistributionalRV( + name=f"{name}_init_log_rt", distribution=dist.Normal(0, 0.2) + ) + + def sample(self, n=None, **kwargs) -> SampledValue: + """ + Sample method + + Returns + ------- + SampledValue + """ + init_rt, *_ = self.rt_init_rv_.sample() + return self.rt_rv_(init_vals=init_rt.value, n=n) + + @staticmethod + def validate(self): + """ + No validation. + """ + pass