Skip to content

Commit

Permalink
Refactor differenced processes, random walks, and AR processes (#380)
Browse files Browse the repository at this point in the history
* Checkpoint commit on differencedar

* Checkpoint commit on differencedar

* Add de-differencing helper and tests

* Replace fori_loop with scan

* Fix bugs and typos in integrator, add test

* Fix name of StandardNormalRandomWalk

* Try two random walk impelmentations

* Fix bug in differenced process

* Fix some tests, add dynamic versus static DistributionalRVs

* Split distributional RVs into static and dynamic

* Update DistributionalRV kwarg dist => distribution in all tests

* update dist => distribution kwarg in DistributionalRV in all tutorials

* Add tests for DistributionalRV factory and classes

* Allow scalars to play better with scan, give more informative error messages for incompatible arrays in certain scan functions

* Autoformat files

* Rename infection_initialization tests to be consistent with name of tested classes, refactor all tests to work with new process module, ensure that DifferencedProcess samples the fundamental process n-1 times to account for inits, use pytest.mark.parametrize in rw sample distribution test

* Fix typo

* Refactor AR and update tests

* Refactor RtPeriodicDiff to use DifferencedProcess and ARProcess; delete unused manual AR function

* Default names for now

* Fundamental process init passthrough for differencedprocess, update tutorials

* Fix typo in IIDRandomSequence class name, add test for class

* Apply suggestions from code review

Co-authored-by: Damon Bayer <[email protected]>

* Reintroduce test_rtperiodicdiff, removing manual reconstruction test that did not use the sampling method

* Coerce to 1d in the appropriate place in infectionswithfeedback

* Reintroduce padding

* Coerce to array in tutorial

* Fix tutorial bug

* Update model/src/pyrenew/process/differencedprocess.py

Co-authored-by: Damon Bayer <[email protected]>

* Restore padding in test_model_basic_renewal

* Restore padding in test in a couple other places

* Apply suggestions from code review

revert to using gen_int.size() method

Co-authored-by: Damon Bayer <[email protected]>

* Convert sds to scalars in test ar process

* Rename PeriodicDiff classes to highlight that they are DiffAR processes

* Move model up one level

* Add expand method for distributional rvs

* Mathtext and refs for differencedprocess class

* Mathtext typo fixes and doc improvements for DifferencedProcess

* More sphinx tweaks

* More typos

* Test for standard normal sequence

* Float expectations

* Add scipy to testing deps

* Update pyproject for deptry

* scipy allowed to be dev dep

* Add test for vectorized sampling

* Raise ValueError if noise sd not scalar

* Fix expand_by() tests to use new syntax

* Manual tests for integrator correctness, better documentation of output shape, revise sample method so that n is chosen properly

* Fix distributional rv test

* Raise eror for overly short DifferenceProcess samples; test for that error raise

* Force differencedprocess to deal with 1D diffs and error otherwise, make it behave as expected for 0 < n < order

* Fix scipy dep

* Update test

* Improve error message

* Autoformat files

* Update tutorials

* More checks and clearer code for differencedprocess.py

* Update src/pyrenew/process/differencedprocess.py

Update link to Hyndman textbook

Co-authored-by: Damon Bayer <[email protected]>

* Update src/pyrenew/process/differencedprocess.py

Fix out of date inaccurate docstring

Co-authored-by: Damon Bayer <[email protected]>

* Update src/pyrenew/process/differencedprocess.py

Co-authored-by: Damon Bayer <[email protected]>

* Remove names for DifferencedProcess

* Change refs to expected shape

* One missing edit in test

* Restore incorrectly removed test

* Update random walk tests

* Make AR strict about 1d input arrays

* Configurable AR process name in rtperiodicdiffar class

* Added check comparing to first initial value to small sample test

* Remove unused function

* Remove required names and adjust tests and structure accordingly

* replace :fun: sphinx directives

* Fix one more :fun:

* And yet one more :fun:

---------

Co-authored-by: Damon Bayer <[email protected]>
  • Loading branch information
dylanhmorris and damonbayer authored Aug 22, 2024
1 parent 7ad81f1 commit a793417
Show file tree
Hide file tree
Showing 38 changed files with 1,500 additions and 708 deletions.
25 changes: 12 additions & 13 deletions docs/source/tutorials/basic_renewal_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from pyrenew.process import SimpleRandomWalkProcess
from pyrenew.process import RandomWalk
from pyrenew.latent import (
Infections,
InfectionInitializationProcess,
Expand Down Expand Up @@ -68,7 +68,7 @@ flowchart LR
end
subgraph process[Process module]
rt["Rt_process_rv\n(Custom class built using SimpleRandomWalk)"]
rt["Rt_process_rv\n(Custom class built using RandomWalk)"]
end
subgraph deterministic[Deterministic module]
Expand Down Expand Up @@ -139,26 +139,25 @@ class MyRt(RandomVariable):
def validate(self):
pass
def sample(self, n_steps: int, **kwargs) -> tuple:
def sample(self, n: int, **kwargs) -> tuple:
sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025))
rt_rv = TransformedRandomVariable(
"Rt_rv",
base_rv=SimpleRandomWalkProcess(
name="log_rt_random_walk",
base_rv=RandomWalk(
name="log_rt",
step_rv=DistributionalRV(
name="rw_step_rv",
distribution=dist.Normal(0, sd_rt),
reparam=LocScaleReparam(0),
),
init_rv=DistributionalRV(
name="init_log_rt",
distribution=dist.Normal(jnp.log(1), jnp.log(1.2)),
name="rw_step_rv", distribution=dist.Normal(0, 0.025)
),
),
transforms=t.ExpTransform(),
)
return rt_rv.sample(n_steps=n_steps, **kwargs)
rt_init_rv = DistributionalRV(
name="init_log_rt", distribution=dist.Normal(0, 0.2)
)
init_rt, *_ = rt_init_rv.sample()
return rt_rv.sample(n=n, init_vals=init_rt.value, **kwargs)
rt_proc = MyRt()
Expand Down
10 changes: 5 additions & 5 deletions docs/source/tutorials/day_of_the_week.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class MyRt(metaclass.RandomVariable):
def validate(self):
pass
def sample(self, n_steps: int, **kwargs) -> tuple:
def sample(self, n: int, **kwargs) -> tuple:
# Standard deviation of the random walk
sd_rt, *_ = self.sd_rv()
Expand All @@ -115,15 +115,14 @@ class MyRt(metaclass.RandomVariable):
name="rw_step_rv", distribution=dist.Normal(0, sd_rt.value)
)
init_rv = metaclass.DistributionalRV(
rt_init_rv = metaclass.DistributionalRV(
name="init_log_rt", distribution=dist.Normal(0, 0.2)
)
# Random walk process
base_rv = process.SimpleRandomWalkProcess(
base_rv = process.RandomWalk(
name="log_rt",
step_rv=step_rv,
init_rv=init_rv,
)
# Transforming the random walk to the Rt scale
Expand All @@ -132,8 +131,9 @@ class MyRt(metaclass.RandomVariable):
base_rv=base_rv,
transforms=transformation.ExpTransform(),
)
init_rt, *_ = rt_init_rv.sample()
return rt_rv(n_steps=n_steps, **kwargs)
return rt_rv.sample(n=n, init_vals=init_rt.value, **kwargs)
rtproc = MyRt(
Expand Down
2 changes: 1 addition & 1 deletion docs/source/tutorials/day_of_the_week.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
.. Please do not edit this file directly.
.. This file is just a placeholder.
.. For the source file, see:
.. <https://github.com/CDCgov/multisignal-epi-inference/tree/main/docs/source/tutorials/day_of_the_week.qmd>
.. <https://github.com/CDCgov/PyRenew/tree/main/docs/source/tutorials/day_of_the_week.qmd>
55 changes: 37 additions & 18 deletions docs/source/tutorials/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,12 @@ import numpyro.distributions as dist
from pyrenew.deterministic import DeterministicPMF, DeterministicVariable
from pyrenew.latent import InfectionsWithFeedback
from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.process import SimpleRandomWalkProcess
from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable
from pyrenew.process import RandomWalk
from pyrenew.metaclass import (
RandomVariable,
DistributionalRV,
TransformedRandomVariable,
)
from pyrenew.latent import (
InfectionInitializationProcess,
InitializeInfectionsExponentialGrowth,
Expand Down Expand Up @@ -62,19 +66,31 @@ latent_infections = InfectionsWithFeedback(
infection_feedback_pmf=gen_int,
)
rt = TransformedRandomVariable(
"Rt_rv",
base_rv=SimpleRandomWalkProcess(
name="log_rt",
step_rv=DistributionalRV(
name="rw_step_rv", distribution=dist.Normal(0, 0.025)
),
init_rv=DistributionalRV(
class MyRt(RandomVariable):
def validate(self):
pass
def sample(self, n: int, **kwargs) -> tuple:
sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025))
rt_rv = TransformedRandomVariable(
name="log_rt_random_walk",
base_rv=RandomWalk(
name="log_rt",
step_rv=DistributionalRV(
name="rw_step_rv", distribution=dist.Normal(0, 0.025)
),
),
transforms=t.ExpTransform(),
)
rt_init_rv = DistributionalRV(
name="init_log_rt", distribution=dist.Normal(0, 0.2)
),
),
transforms=t.ExpTransform(),
)
)
init_rt, *_ = rt_init_rv.sample()
return rt_rv.sample(n=n, init_vals=init_rt.value, **kwargs)
```

With all the components defined, we can build the model:
Expand All @@ -85,7 +101,7 @@ model0 = RtInfectionsRenewalModel(
gen_int_rv=gen_int,
I0_rv=I0,
latent_infections_rv=latent_infections,
Rt_process_rv=rt,
Rt_process_rv=MyRt(),
infection_obs_process_rv=None,
)
```
Expand Down Expand Up @@ -209,10 +225,13 @@ class InfFeedback(RandomVariable):
inf_feedback_strength, *_ = self.infection_feedback_strength(
**kwargs,
)
inf_feedback_strength = jnp.atleast_1d(inf_feedback_strength.value)
inf_feedback_strength = au.pad_x_to_match_y(
x=inf_feedback_strength.value,
x=inf_feedback_strength,
y=Rt,
fill_value=inf_feedback_strength.value[0],
fill_value=inf_feedback_strength[0],
)
# Sampling inf feedback and adjusting the shape
Expand Down Expand Up @@ -260,7 +279,7 @@ model1 = RtInfectionsRenewalModel(
gen_int_rv=gen_int,
I0_rv=I0,
latent_infections_rv=latent_infections2,
Rt_process_rv=rt,
Rt_process_rv=MyRt(),
infection_obs_process_rv=None,
)
Expand Down
27 changes: 11 additions & 16 deletions docs/source/tutorials/hospital_admissions_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -186,37 +186,32 @@ gen_int = deterministic.DeterministicPMF(name="gen_int", value=gen_int)
class MyRt(metaclass.RandomVariable):
def __init__(self, sd_rv):
self.sd_rv = sd_rv
def validate(self):
pass
def sample(self, n_steps: int, **kwargs) -> tuple:
sd_rt, *_ = self.sd_rv()
def sample(self, n: int, **kwargs) -> tuple:
sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025))
rt_rv = metaclass.TransformedRandomVariable(
"Rt_rv",
base_rv=process.SimpleRandomWalkProcess(
name="log_rt_random_walk",
base_rv=process.RandomWalk(
name="log_rt",
step_rv=metaclass.DistributionalRV(
name="rw_step_rv", distribution=dist.Normal(0, sd_rt.value)
),
init_rv=metaclass.DistributionalRV(
name="init_log_rt", distribution=dist.Normal(0, 0.2)
name="rw_step_rv", distribution=dist.Normal(0, 0.025)
),
),
transforms=transformation.ExpTransform(),
)
rt_init_rv = metaclass.DistributionalRV(
name="init_log_rt", distribution=dist.Normal(0, 0.2)
)
init_rt, *_ = rt_init_rv.sample()
return rt_rv.sample(n_steps=n_steps, **kwargs)
return rt_rv.sample(n=n, init_vals=init_rt.value, **kwargs)
rtproc = MyRt(
metaclass.DistributionalRV(
name="Rt_random_walk_sd", distribution=dist.HalfNormal(0.025)
)
)
rtproc = MyRt()
# The observation model
Expand Down
4 changes: 2 additions & 2 deletions docs/source/tutorials/periodic_effects.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ from pyrenew import process, deterministic

```{python}
# The random process for Rt
rt_proc = process.RtWeeklyDiffProcess(
rt_proc = process.RtWeeklyDiffARProcess(
name="rt_weekly_diff",
offset=0,
log_rt_rv=deterministic.DeterministicVariable(
Expand Down Expand Up @@ -57,7 +57,7 @@ for i in range(0, 30, 7):
plt.show()
```

The implementation of the `RtWeeklyDiffProcess` (which is an instance of `RtPeriodicDiffProcess`), uses `repeat_until_n` to repeating values: `repeat_until_n(..., period_size=7)`. The `RtWeeklyDiff` class is a particular case of `RtPeriodicDiff` with a period size of seven.
The implementation of the `RtWeeklyDiffARProcess` (which is an instance of `RtPeriodicDiffARProcess`), uses `repeat_until_n` to repeating values: `repeat_until_n(..., period_size=7)`. The `RtWeeklyDiff` class is a particular case of `RtPeriodicDiff` with a period size of seven.

## Repeated sequences (tiling)

Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ sphinxcontrib-mermaid = "^0.9.2"
sphinx-autodoc-typehints = "^2.1.0"
sphinx-book-theme = "^1.1.2"
ipywidgets = "^8.1.3"
jupyter = "^1.0.0"

[tool.poetry.group.test]
optional = true
Expand All @@ -49,6 +50,7 @@ optional = true
pytest = "^8.3.2"
pytest-cov = "^5.0.0"
pytest-mpl = "^0.17.0"
scipy = "^1.14.1"

[tool.numpydoc_validation]
checks = [
Expand Down Expand Up @@ -81,4 +83,4 @@ build-backend = "poetry.core.masonry.api"
known_first_party = ["pyrenew", "test"]

[tool.deptry.per_rule_ignores]
DEP004 = ["pytest"]
DEP004 = ["pytest", "scipy"]
3 changes: 1 addition & 2 deletions src/pyrenew/deterministic/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from __future__ import annotations

import jax.numpy as jnp
import numpyro
from jax.typing import ArrayLike

Expand Down Expand Up @@ -42,7 +41,7 @@ def __init__(
"""
self.name = name
self.validate(value)
self.value = jnp.atleast_1d(value)
self.value = value
self.set_timeseries(t_start, t_unit)

return None
Expand Down
13 changes: 9 additions & 4 deletions src/pyrenew/latent/infection_initialization_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def initialize_infections(self, I_pre_init: ArrayLike):
ArrayLike
An array of length ``n_timepoints`` with the number of initialized infections at each time point.
"""
I_pre_init = jnp.atleast_1d(I_pre_init)
if self.n_timepoints < I_pre_init.size:
raise ValueError(
"I_pre_init must be no longer than n_timepoints. "
Expand All @@ -105,20 +106,23 @@ def initialize_infections(self, I_pre_init: ArrayLike):
Parameters
----------
I_pre_init : ArrayLike
An array with the same length as ``n_timepoints`` to be used as the initial infections.
An array with the same length as ``n_timepoints`` to be
used as the initial infections.
Returns
-------
ArrayLike
An array of length ``n_timepoints`` with the number of initialized infections at each time point.
An array of length ``n_timepoints`` with the number of
initialized infections at each time point.
"""
I_pre_init = jnp.array(I_pre_init)
if I_pre_init.size != self.n_timepoints:
raise ValueError(
"I_pre_init must have the same size as n_timepoints. "
f"Got I_pre_init of size {I_pre_init.size} "
f"and n_timepoints of size {self.n_timepoints}."
)
return jnp.array(I_pre_init)
return I_pre_init


class InitializeInfectionsExponentialGrowth(InfectionInitializationMethod):
Expand Down Expand Up @@ -173,11 +177,12 @@ def initialize_infections(self, I_pre_init: ArrayLike):
ArrayLike
An array of length ``n_timepoints`` with the number of initialized infections at each time point.
"""
I_pre_init = jnp.array(I_pre_init)
if I_pre_init.size != 1:
raise ValueError(
f"I_pre_init must be an array of size 1. Got size {I_pre_init.size}."
)
rate = self.rate()[0].value
rate = jnp.array(self.rate()[0].value)
if rate.size != 1:
raise ValueError(
f"rate must be an array of size 1. Got size {rate.size}."
Expand Down
3 changes: 0 additions & 3 deletions src/pyrenew/latent/infection_initialization_process.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# -*- coding: utf-8 -*-
# numpydoc ignore=GL08
import numpyro

from pyrenew.latent.infection_initialization_method import (
InfectionInitializationMethod,
)
Expand Down Expand Up @@ -97,7 +95,6 @@ def sample(self) -> tuple:
infection_initialization = self.infection_init_method(
I_pre_init.value,
)
numpyro.deterministic(self.name, infection_initialization)

return (
SampledValue(
Expand Down
14 changes: 8 additions & 6 deletions src/pyrenew/latent/infectionswithfeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,11 @@ def sample(
I0 = I0[-gen_int_rev.size :]

# Sampling inf feedback strength
inf_feedback_strength = self.infection_feedback_strength(
**kwargs,
)[0].value

inf_feedback_strength = jnp.atleast_1d(
self.infection_feedback_strength(
**kwargs,
)[0].value
)
# Making sure inf_feedback_strength spans the Rt length
if inf_feedback_strength.size == 1:
inf_feedback_strength = au.pad_x_to_match_y(
Expand All @@ -173,8 +174,9 @@ def sample(
)
elif inf_feedback_strength.size != Rt.size:
raise ValueError(
"Infection feedback strength must be of size 1 or the same "
f"size as the reproduction number. Got {inf_feedback_strength.size} "
"Infection feedback strength must be of size 1 "
"or the same size as the reproduction number array. "
f"Got {inf_feedback_strength.size} "
f"and {Rt.size} respectively."
)

Expand Down
Loading

0 comments on commit a793417

Please sign in to comment.