
Commit

ran make format
SamWitty committed Oct 24, 2023
1 parent c6c624a commit 2399e95
Showing 6 changed files with 135 additions and 125 deletions.
44 changes: 18 additions & 26 deletions src/pyciemss/Ensemble/interfaces.py
@@ -1,40 +1,32 @@
import copy
from collections.abc import Iterable
from typing import Callable, Optional, Tuple, Union

import pyro
import torch

from pyro.infer import Predictive

from pyciemss.PetriNetODE.base import get_name
from pyciemss.interfaces import (
setup_model,
reset_model,
intervene,
sample,
calibrate,
DynamicalSystem,
prepare_interchange_dictionary,
DEFAULT_QUANTILES,
)

from pyciemss.utils import interface_utils

from pyciemss.custom_decorators import pyciemss_logging_wrapper
from pyciemss.Ensemble.base import (
EnsembleSystem,
ScaledBetaNoiseEnsembleSystem,
ScaledNormalNoiseEnsembleSystem,
)

from typing import Optional, Tuple, Callable, Union
from collections.abc import Iterable
import copy

# TODO: probably refactor this out later.
from pyciemss.PetriNetODE.events import (
StartEvent,
ObservationEvent,
LoggingEvent,
from pyciemss.interfaces import (
DEFAULT_QUANTILES,
DynamicalSystem,
calibrate,
intervene,
prepare_interchange_dictionary,
reset_model,
sample,
setup_model,
)
from pyciemss.PetriNetODE.base import get_name

from pyciemss.custom_decorators import pyciemss_logging_wrapper
# TODO: probably refactor this out later.
from pyciemss.PetriNetODE.events import LoggingEvent, ObservationEvent, StartEvent
from pyciemss.utils import interface_utils
from pyciemss.utils.interface_utils import convert_to_output_format
from pyciemss.visuals import plots

21 changes: 7 additions & 14 deletions src/pyciemss/Ensemble/interfaces_bigbox.py
@@ -1,26 +1,19 @@
from typing_extensions import deprecated
import pyro
from typing import Iterable, Optional, Union

import mira
import pyro
from typing_extensions import deprecated

from ..interfaces import DEFAULT_QUANTILES
from .interfaces import (
setup_model,
sample,
calibrate,
prepare_interchange_dictionary,
)

from pyciemss.integration_utils.custom_decorators import pyciemss_logging_wrapper
from pyciemss.ODE.base import get_name
from pyciemss.ODE.interfaces import load_petri_model

from pyciemss.utils.interface_utils import (
csv_to_list,
create_mapping_function_from_observables,
csv_to_list,
)

from typing import Iterable, Optional, Union
from pyciemss.integration_utils.custom_decorators import pyciemss_logging_wrapper
from ..interfaces import DEFAULT_QUANTILES
from .interfaces import calibrate, prepare_interchange_dictionary, sample, setup_model


@deprecated(
144 changes: 88 additions & 56 deletions src/pyciemss/ODE/interfaces.py
@@ -1,52 +1,47 @@
import pyro
import torch
import time
import numpy as np
from math import ceil
from typing import Iterable, Optional, Tuple, Union, Callable
import copy
import warnings

import random as rand
import time
import warnings
from math import ceil
from typing import Callable, Iterable, Optional, Tuple, Union

from torch.distributions import biject_to

import mira
import numpy as np
import pyro
import torch
from pyro.infer import Predictive
from pyro.infer.autoguide import AutoLowRankMultivariateNormal
from torch.distributions import biject_to

from pyciemss.ODE.base import (
PetriNetODESystem,
ScaledNormalNoisePetriNetODESystem,
ScaledBetaNoisePetriNetODESystem,
get_name,
)
from pyciemss.risk.ouu import computeRisk, solveOUU
from pyciemss.risk.risk_measures import alpha_superquantile
from pyciemss.utils.interface_utils import convert_to_output_format
from pyciemss.visuals import plots

import mira
from pyciemss.integration_utils.custom_decorators import pyciemss_logging_wrapper

# Load base interfaces
from pyciemss.interfaces import (
setup_model,
reset_model,
intervene,
sample,
DEFAULT_QUANTILES,
calibrate,
intervene,
optimize,
prepare_interchange_dictionary,
DEFAULT_QUANTILES
reset_model,
sample,
setup_model,
)
from pyciemss.ODE.base import (
PetriNetODESystem,
ScaledBetaNoisePetriNetODESystem,
ScaledNormalNoisePetriNetODESystem,
get_name,
)

from pyciemss.ODE.events import (
StartEvent,
ObservationEvent,
LoggingEvent,
ObservationEvent,
StartEvent,
StaticParameterInterventionEvent,
)

from pyciemss.integration_utils.custom_decorators import pyciemss_logging_wrapper
from pyciemss.risk.ouu import computeRisk, solveOUU
from pyciemss.risk.risk_measures import alpha_superquantile
from pyciemss.utils.interface_utils import convert_to_output_format
from pyciemss.visuals import plots

# TODO: These interfaces should probably be just in terms of JSON-like objects.

@@ -60,21 +55,31 @@ def load_petri_model(
noise_model: str = "scaled_normal",
noise_scale: float = 0.1,
compile_observables_p: bool = False,
compile_rate_law_p: bool = False
compile_rate_law_p: bool = False,
) -> PetriNetODESystem:
"""
Load a petri net from a file and compile it into a probabilistic program.
"""
if noise_model == "scaled_beta":
return ScaledBetaNoisePetriNetODESystem.from_askenet(
petri_model_or_path, noise_scale=noise_scale, compile_rate_law_p=compile_rate_law_p, compile_observables_p=compile_observables_p, add_uncertainty=add_uncertainty
petri_model_or_path,
noise_scale=noise_scale,
compile_rate_law_p=compile_rate_law_p,
compile_observables_p=compile_observables_p,
add_uncertainty=add_uncertainty,
)
elif noise_model == "scaled_normal":
return ScaledNormalNoisePetriNetODESystem.from_askenet(
petri_model_or_path, noise_scale=noise_scale, compile_rate_law_p=compile_rate_law_p, compile_observables_p=compile_observables_p, add_uncertainty=add_uncertainty
petri_model_or_path,
noise_scale=noise_scale,
compile_rate_law_p=compile_rate_law_p,
compile_observables_p=compile_observables_p,
add_uncertainty=add_uncertainty,
)
else:
raise ValueError(f"Unknown noise model {noise_model}. Please select from either 'scaled_beta' or 'scaled_normal'.")
raise ValueError(
f"Unknown noise model {noise_model}. Please select from either 'scaled_beta' or 'scaled_normal'."
)
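
For reference, a minimal usage sketch of the reformatted load_petri_model above. The model path is a placeholder, the add_uncertainty parameter that sits above this hunk is assumed to keep its default, and none of this is part of the commit itself.

    # Sketch only: "model_askenet.json" is a hypothetical AMR model file.
    from pyciemss.ODE.interfaces import load_petri_model

    petri = load_petri_model(
        "model_askenet.json",
        noise_model="scaled_normal",  # or "scaled_beta"; anything else raises ValueError
        noise_scale=0.1,
        compile_rate_law_p=False,
        compile_observables_p=False,
    )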


@setup_model.register
@@ -91,7 +96,7 @@ def setup_petri_model(
start_state = {
get_name(v): v.data["initial_value"] for v in petri.G.variables.values()
}

# TODO: Figure out how to do this without copying the petri net.
start_event = StartEvent(start_time, start_state)
new_petri = copy.deepcopy(petri)
@@ -114,20 +119,25 @@ def reset_petri_model(petri: PetriNetODESystem) -> PetriNetODESystem:
@intervene.register
@pyciemss_logging_wrapper
def intervene_petri_model(
petri: PetriNetODESystem, interventions: Iterable[Tuple[float, str, float]], jostle_scale: float = 1e-5
petri: PetriNetODESystem,
interventions: Iterable[Tuple[float, str, float]],
jostle_scale: float = 1e-5,
) -> PetriNetODESystem:
"""
Intervene on a model.
"""
# Note: this will have to change if we want to add more sophisticated interventions.
interventions = [
StaticParameterInterventionEvent(timepoint + (0.1+rand.random())*jostle_scale, parameter, value)
StaticParameterInterventionEvent(
timepoint + (0.1 + rand.random()) * jostle_scale, parameter, value
)
for timepoint, parameter, value in interventions
]
new_petri = copy.deepcopy(petri)
new_petri.load_events(interventions)
return new_petri
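
A minimal sketch of calling the intervene dispatch on the reformatted intervene_petri_model above, given a PetriNetODESystem such as the petri object from the earlier sketch. The timepoint, parameter name, and value are hypothetical and not taken from this commit.

    # Sketch only: interventions are (time, parameter, value) triples per the
    # signature above; "beta" is a hypothetical parameter name.
    from pyciemss.interfaces import intervene

    intervened_petri = intervene(petri, [(10.0, "beta", 0.25)])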


@calibrate.register
@pyciemss_logging_wrapper
def calibrate_petri(
@@ -145,19 +155,22 @@ def calibrate_petri(
"""
Use variational inference with a mean-field variational family to infer the parameters of the model.
"""

new_petri = copy.deepcopy(petri)
observations = [
ObservationEvent(timepoint + (0.1+rand.random()) * jostle_scale, observation) for timepoint, observation in data
ObservationEvent(timepoint + (0.1 + rand.random()) * jostle_scale, observation)
for timepoint, observation in data
]

for obs in observations:
s = 0.0
for v in obs.observation.values():
s += v
if not 0 <= v <= petri.total_population:
warnings.warn(f"Observation {obs} is not in the range [0, {petri.total_population}]. This may be an error!")
#assert s <= petri.total_population or torch.isclose(s, petri.total_population)
warnings.warn(
f"Observation {obs} is not in the range [0, {petri.total_population}]. This may be an error!"
)
# assert s <= petri.total_population or torch.isclose(s, petri.total_population)
new_petri.load_events(observations)

guide = autoguide(new_petri)
Expand All @@ -168,7 +181,7 @@ def calibrate_petri(
pyro.clear_param_store()

for i in range(num_iterations):
progress_hook(i/num_iterations)
progress_hook(i / num_iterations)
loss = svi.step(method=method)
if verbose:
if i % 25 == 0:
@@ -178,29 +191,35 @@ def get_posterior_density_petri(inferred_parameters: PetriInferredParameters,


@pyciemss_logging_wrapper
def get_posterior_density_mesh_petri(inferred_parameters: PetriInferredParameters,
mesh_params: Optional[dict[str, list[float]]]) -> float:
def get_posterior_density_mesh_petri(
inferred_parameters: PetriInferredParameters,
mesh_params: Optional[dict[str, list[float]]],
) -> float:
"""
Compute the log posterior density of the inferred parameters at the given parameter values.
Args:
inferred_parameters: PetriInferredParameters
- The inferred parameters from the calibration.
mesh_params: dict[str, list]
- Parameter values used to compute a mesh of sample points.
- Parameter values used to compute a mesh of sample points.
Keys are parameter names, values are (min, max, steps) parameters passed to linspace.
Returns:
log_density: float
- The log posterior density of the inferred parameters at the given parameter values.
"""
"""
spaces = [torch.linspace(*params) for params in mesh_params.values()]
parameter_values = dict(zip(mesh_params.keys(), torch.meshgrid(*spaces, indexing='ij')))
parameter_values = dict(
zip(mesh_params.keys(), torch.meshgrid(*spaces, indexing="ij"))
)
density = get_posterior_density_petri(inferred_parameters, parameter_values)
return parameter_values, density
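
A minimal sketch of the mesh-density helper above. The parameter names and ranges are hypothetical, inferred_parameters stands for the PetriInferredParameters returned by calibration, and, per the check in get_posterior_density_petri below, the mesh must cover the full calibrated parameter space; none of this is part of the commit.

    # Sketch only: keys are hypothetical parameter names; each value is the
    # (min, max, steps) triple handed to torch.linspace, as documented above.
    from pyciemss.ODE.interfaces import get_posterior_density_mesh_petri

    mesh_params = {"beta": (0.05, 0.5, 25), "gamma": (0.05, 0.5, 25)}
    values, density = get_posterior_density_mesh_petri(inferred_parameters, mesh_params)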


@pyciemss_logging_wrapper
def get_posterior_density_petri(inferred_parameters: PetriInferredParameters,
parameter_values: dict[str, Union[list[float], torch.tensor]]) -> float:
def get_posterior_density_petri(
inferred_parameters: PetriInferredParameters,
parameter_values: dict[str, Union[list[float], torch.tensor]],
) -> float:
"""
Compute the log posterior density of the inferred parameters at the given parameter values.
Args:
@@ -213,24 +232,36 @@ def get_posterior_density_petri(inferred_parameters: PetriInferredParameters,
- The log posterior density of the inferred parameters at the given parameter values.
"""

guides = [guide for guide in inferred_parameters if type(guide) == AutoLowRankMultivariateNormal]
guides = [
guide
for guide in inferred_parameters
if type(guide) == AutoLowRankMultivariateNormal
]

# By construction there should be only a single AutoLowRankMultivariateNormal guide. The rest should be AutoDeltas.
if len(guides) != 1:
raise ValueError(f"Expected a single AutoLowRankMultivariateNormal guide, but found {len(guides)} guides.")
raise ValueError(
f"Expected a single AutoLowRankMultivariateNormal guide, but found {len(guides)} guides."
)

guide = guides[0]

# For now we only support density evaluation on the full parameter space.
if guide.loc.shape[0] != len(parameter_values):
raise ValueError(f"Expected {guide.loc.shape[0]} parameters, but found {len(parameter_values)} parameters.")
raise ValueError(
f"Expected {guide.loc.shape[0]} parameters, but found {len(parameter_values)} parameters."
)

parameter_values = {name: torch.as_tensor(value) for name, value in parameter_values.items()}
parameter_values = {
name: torch.as_tensor(value) for name, value in parameter_values.items()
}

# Assert that all of the parameters in the `parameter_values` are the same size.
parameter_sizes = set([value.size() for value in parameter_values.values()])
if len(parameter_sizes) != 1:
raise ValueError(f"Expected all parameter values to have the same size, but found {len(parameter_sizes)} distinct sizes.")
raise ValueError(
f"Expected all parameter values to have the same size, but found {len(parameter_sizes)} distinct sizes."
)

parameter_size = parameter_sizes.pop()

@@ -248,6 +279,7 @@ def get_posterior_density_petri(inferred_parameters: PetriInferredParameters,

return torch.exp(log_density).detach()


@sample.register
@pyciemss_logging_wrapper
def sample_petri(
