Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revise basic SIR tutorial #377

Merged
merged 46 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
5bb60cc
changes dynamical systems intro to be more short-term focussed
SamWitty Nov 13, 2023
d164bb8
suppress warnings
SamWitty Nov 13, 2023
55f35ff
remove counterfactual demo
SamWitty Nov 13, 2023
5690500
Merge branch 'master' of https://github.com/BasisResearch/causal_pyro…
SamWitty Nov 28, 2023
7d21af6
remove unused imports and reorder imports
SamWitty Nov 28, 2023
f5b6d24
add failing LogTrajectory test
SamWitty Nov 28, 2023
9fd7c52
revise tests to exercise start and end time collisions
SamWitty Nov 28, 2023
6b3f985
removed unnecessary imports from test
SamWitty Nov 28, 2023
1cedd1e
Merge branch 'master' of https://github.com/BasisResearch/causal_pyro…
SamWitty Nov 28, 2023
b479d32
Merge branch 'master' of https://github.com/BasisResearch/causal_pyro…
SamWitty Nov 28, 2023
dc69b91
much simpler implementation
SamWitty Nov 30, 2023
1abc181
lint and comment
SamWitty Nov 30, 2023
ae4d013
added some functional indirection to appease linter
SamWitty Nov 30, 2023
2693668
lint
SamWitty Nov 30, 2023
9c2e936
type refinement
SamWitty Nov 30, 2023
f513a0e
nit about arg unpacking order
SamWitty Nov 30, 2023
053d047
Merge branch 'master' of https://github.com/BasisResearch/causal_pyro…
SamWitty Dec 4, 2023
fe666b9
Merge branch 'sw-time-collision' of https://github.com/BasisResearch/…
SamWitty Dec 4, 2023
ff0a626
added multiple simulate handling
SamWitty Dec 4, 2023
1b84677
remove commented stop
SamWitty Dec 4, 2023
6b20372
lint
SamWitty Dec 4, 2023
f2132ce
made BatchObservation handler use a continuation to guarantee it's ap…
SamWitty Dec 4, 2023
ce605a0
lint
SamWitty Dec 4, 2023
14ad3dc
Merge branch 'master' into sw-revise-SIR-tutorial
SamWitty Dec 4, 2023
e81ee0b
Merge branch 'sw-time-collision' of https://github.com/BasisResearch/…
SamWitty Dec 4, 2023
90213ee
add dynamical intro notebook to CI build
SamWitty Dec 4, 2023
12ab235
add CI test parameters
SamWitty Dec 4, 2023
e14dbb6
add a bunch of textual content
SamWitty Dec 4, 2023
5773fd3
Added description of example
SamWitty Dec 5, 2023
cd5ed6d
fixed bug in inference and added some plot changes
SamWitty Dec 5, 2023
e89ea65
fixed bug in inference and added some plot changes
SamWitty Dec 5, 2023
ef25bf4
updated plots
SamWitty Dec 5, 2023
ddbb126
updated plots
SamWitty Dec 5, 2023
3802246
more plot updates
SamWitty Dec 5, 2023
e0077db
more plot updates
SamWitty Dec 5, 2023
da80e59
finished first pass of text, will edit tomorrow morning
SamWitty Dec 5, 2023
78f1d9f
reran and minor edits
SamWitty Dec 5, 2023
54cbfb2
text edits
SamWitty Dec 6, 2023
c732ae6
add glossary
SamWitty Dec 6, 2023
440d227
Merge branch 'master' into sw-revise-SIR-tutorial
SamWitty Dec 7, 2023
3b6a4f6
Update solver.py
SamWitty Dec 7, 2023
c0144a8
previous commit fix
SamWitty Dec 7, 2023
cf913c5
add dynamical systems dependencies to CI workflow
SamWitty Dec 7, 2023
0394737
Merge branch 'sw-revise-SIR-tutorial' of https://github.com/BasisRese…
SamWitty Dec 7, 2023
9ede4ce
addressing comments
SamWitty Dec 7, 2023
d5ada3a
reran notebook
SamWitty Dec 7, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test_notebooks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
- name: Install Python packages from requirements.txt
run: |
pip install --upgrade pip
pip install -e .[test]
pip install -e .[test,dynamical]

- name: Run Notebook Test
run: |
Expand Down
3 changes: 2 additions & 1 deletion chirho/dynamical/handlers/interruption.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,10 @@ def __init__(
self,
times: torch.Tensor,
observation: Observation[State[T]],
**kwargs,
):
self.observation = observation
super().__init__(times)
super().__init__(times, **kwargs)

def _pyro_post_simulate(self, msg: dict) -> None:
super()._pyro_post_simulate(msg)
Expand Down
7 changes: 6 additions & 1 deletion chirho/dynamical/handlers/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ class LogTrajectory(Generic[T], pyro.poutine.messenger.Messenger):
trajectory: State[T]
_trajectory: State[T]

def __init__(self, times: torch.Tensor):
def __init__(self, times: torch.Tensor, is_traced: bool = False):
self.times = times
self._trajectory: State[T] = State()
self.is_traced = is_traced

# Require that the times are sorted. This is required by the index masking we do below.
if not torch.all(self.times[1:] > self.times[:-1]):
Expand All @@ -46,6 +47,10 @@ def _pyro_post_simulate(self, msg) -> None:
self.trajectory = self._trajectory
self._trajectory: State[T] = State()

if self.is_traced:
# This adds the trajectory to the trace so that it can be accessed later.
[pyro.deterministic(name, value) for name, value in self.trajectory.items()]

def _pyro_simulate_point(self, msg) -> None:
# Turn a simulate that returns a state into a simulate that returns a trajectory at each of the logging_times
dynamics, initial_state, start_time, end_time = msg["args"]
Expand Down
Loading
Loading