Skip to content

Commit

Permalink
resolves pr change requests.
Browse files Browse the repository at this point in the history
  • Loading branch information
azane committed Sep 11, 2024
1 parent cda2f56 commit 3de0de8
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 13 deletions.
6 changes: 2 additions & 4 deletions tests/dynamical/dynamical_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import pyro
import torch
from pyro.distributions import Normal, Uniform, constraints
from torch import Tensor as Tnsr

from chirho.dynamical.ops import State

Expand All @@ -16,8 +15,8 @@

# SIR dynamics written as a pure function of state and parameters.
def pure_sir_dynamics(
state: State[Tnsr], atemp_params: ATempParams[Tnsr]
) -> State[Tnsr]:
state: State[torch.Tensor], atemp_params: ATempParams[torch.Tensor]
) -> State[torch.Tensor]:
beta = atemp_params["beta"]
gamma = atemp_params["gamma"]

Expand Down Expand Up @@ -50,7 +49,6 @@ def observation(self, X: State[torch.Tensor]):

class SIRReparamObservationMixin(SIRObservationMixin):
def observation(self, X: State[torch.Tensor]):
# super().observation(X)

# A flight arrives in a country that tests all arrivals for a disease. The number of people infected on the
# plane is a noisy function of the number of infected people in the country of origin at that time.
Expand Down
2 changes: 0 additions & 2 deletions tests/dynamical/test_handler_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,9 @@


def counterf_model(solver, model):
# model = UnifiedFixtureDynamicsReparam(beta=0.5, gamma=0.7)
obs = condition(data=flight_landing_data)(model.observation)
vec_obs3 = StaticBatchObservation(times=flight_landing_times, observation=obs)
with vec_obs3:
# with TorchDiffEq():
with solver():
with reparam, twin_world, intervention:
return simulate(
Expand Down
14 changes: 7 additions & 7 deletions tests/dynamical/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,26 @@
end_time = torch.tensor(4.0)


def test_no_backend_error():
def test_no_solver_error():
sir = bayes_sir_model()
with pytest.raises(NotImplementedError):
simulate(sir, init_state, start_time, end_time)


@pytest.mark.parametrize("backend", [TorchDiffEq])
@pytest.mark.parametrize("solver", [TorchDiffEq])
@pytest.mark.parametrize("dynamics", [bayes_sir_model()])
def test_backend_arg(backend, dynamics):
with backend():
def test_solver_arg(solver, dynamics):
with solver():
result = simulate(dynamics, init_state, start_time, end_time)
assert result is not None


@pytest.mark.parametrize("backend", [TorchDiffEq])
@pytest.mark.parametrize("solver", [TorchDiffEq])
@pytest.mark.parametrize("dynamics_builder", [bayes_sir_model])
def test_broadcasting(backend, dynamics_builder):
def test_broadcasting(solver, dynamics_builder):
with pyro.plate("plate", 3):
dynamics = dynamics_builder()
with backend():
with solver():
result = simulate(dynamics, init_state, start_time, end_time)

for v in result.values():
Expand Down

0 comments on commit 3de0de8

Please sign in to comment.