From 2399e954261fc84016cbb34e9faa48cb4ebd52ca Mon Sep 17 00:00:00 2001 From: Sam Witty Date: Tue, 24 Oct 2023 10:56:14 -0400 Subject: [PATCH] ran make format --- src/pyciemss/Ensemble/interfaces.py | 44 +++--- src/pyciemss/Ensemble/interfaces_bigbox.py | 21 +-- src/pyciemss/ODE/interfaces.py | 144 +++++++++++------- src/pyciemss/ODE/interfaces_bigbox.py | 34 ++--- .../integration_utils/custom_decorators.py | 10 +- src/pyciemss/interfaces.py | 7 +- 6 files changed, 135 insertions(+), 125 deletions(-) diff --git a/src/pyciemss/Ensemble/interfaces.py b/src/pyciemss/Ensemble/interfaces.py index 0c46454ae..3051a2da6 100644 --- a/src/pyciemss/Ensemble/interfaces.py +++ b/src/pyciemss/Ensemble/interfaces.py @@ -1,40 +1,32 @@ +import copy +from collections.abc import Iterable +from typing import Callable, Optional, Tuple, Union + import pyro import torch - from pyro.infer import Predictive -from pyciemss.PetriNetODE.base import get_name -from pyciemss.interfaces import ( - setup_model, - reset_model, - intervene, - sample, - calibrate, - DynamicalSystem, - prepare_interchange_dictionary, - DEFAULT_QUANTILES, -) - -from pyciemss.utils import interface_utils - +from pyciemss.custom_decorators import pyciemss_logging_wrapper from pyciemss.Ensemble.base import ( EnsembleSystem, ScaledBetaNoiseEnsembleSystem, ScaledNormalNoiseEnsembleSystem, ) - -from typing import Optional, Tuple, Callable, Union -from collections.abc import Iterable -import copy - -# TODO: probably refactor this out later. -from pyciemss.PetriNetODE.events import ( - StartEvent, - ObservationEvent, - LoggingEvent, +from pyciemss.interfaces import ( + DEFAULT_QUANTILES, + DynamicalSystem, + calibrate, + intervene, + prepare_interchange_dictionary, + reset_model, + sample, + setup_model, ) +from pyciemss.PetriNetODE.base import get_name -from pyciemss.custom_decorators import pyciemss_logging_wrapper +# TODO: probably refactor this out later. +from pyciemss.PetriNetODE.events import LoggingEvent, ObservationEvent, StartEvent +from pyciemss.utils import interface_utils from pyciemss.utils.interface_utils import convert_to_output_format from pyciemss.visuals import plots diff --git a/src/pyciemss/Ensemble/interfaces_bigbox.py b/src/pyciemss/Ensemble/interfaces_bigbox.py index c0cc66f50..625614d45 100644 --- a/src/pyciemss/Ensemble/interfaces_bigbox.py +++ b/src/pyciemss/Ensemble/interfaces_bigbox.py @@ -1,26 +1,19 @@ -from typing_extensions import deprecated -import pyro +from typing import Iterable, Optional, Union import mira +import pyro +from typing_extensions import deprecated -from ..interfaces import DEFAULT_QUANTILES -from .interfaces import ( - setup_model, - sample, - calibrate, - prepare_interchange_dictionary, -) - +from pyciemss.integration_utils.custom_decorators import pyciemss_logging_wrapper from pyciemss.ODE.base import get_name from pyciemss.ODE.interfaces import load_petri_model - from pyciemss.utils.interface_utils import ( - csv_to_list, create_mapping_function_from_observables, + csv_to_list, ) -from typing import Iterable, Optional, Union -from pyciemss.integration_utils.custom_decorators import pyciemss_logging_wrapper +from ..interfaces import DEFAULT_QUANTILES +from .interfaces import calibrate, prepare_interchange_dictionary, sample, setup_model @deprecated( diff --git a/src/pyciemss/ODE/interfaces.py b/src/pyciemss/ODE/interfaces.py index 83d8a9dd1..e52bb9174 100644 --- a/src/pyciemss/ODE/interfaces.py +++ b/src/pyciemss/ODE/interfaces.py @@ -1,52 +1,47 @@ -import pyro -import torch -import time -import numpy as np -from math import ceil -from typing import Iterable, Optional, Tuple, Union, Callable import copy -import warnings - import random as rand +import time +import warnings +from math import ceil +from typing import Callable, Iterable, Optional, Tuple, Union -from torch.distributions import biject_to - +import mira +import numpy as np +import pyro +import torch from pyro.infer import Predictive from pyro.infer.autoguide import AutoLowRankMultivariateNormal +from torch.distributions import biject_to -from pyciemss.ODE.base import ( - PetriNetODESystem, - ScaledNormalNoisePetriNetODESystem, - ScaledBetaNoisePetriNetODESystem, - get_name, -) -from pyciemss.risk.ouu import computeRisk, solveOUU -from pyciemss.risk.risk_measures import alpha_superquantile -from pyciemss.utils.interface_utils import convert_to_output_format -from pyciemss.visuals import plots - -import mira +from pyciemss.integration_utils.custom_decorators import pyciemss_logging_wrapper # Load base interfaces from pyciemss.interfaces import ( - setup_model, - reset_model, - intervene, - sample, + DEFAULT_QUANTILES, calibrate, + intervene, optimize, prepare_interchange_dictionary, - DEFAULT_QUANTILES + reset_model, + sample, + setup_model, +) +from pyciemss.ODE.base import ( + PetriNetODESystem, + ScaledBetaNoisePetriNetODESystem, + ScaledNormalNoisePetriNetODESystem, + get_name, ) - from pyciemss.ODE.events import ( - StartEvent, - ObservationEvent, LoggingEvent, + ObservationEvent, + StartEvent, StaticParameterInterventionEvent, ) - -from pyciemss.integration_utils.custom_decorators import pyciemss_logging_wrapper +from pyciemss.risk.ouu import computeRisk, solveOUU +from pyciemss.risk.risk_measures import alpha_superquantile +from pyciemss.utils.interface_utils import convert_to_output_format +from pyciemss.visuals import plots # TODO: These interfaces should probably be just in terms of JSON-like objects. @@ -60,21 +55,31 @@ def load_petri_model( noise_model: str = "scaled_normal", noise_scale: float = 0.1, compile_observables_p: bool = False, - compile_rate_law_p: bool = False + compile_rate_law_p: bool = False, ) -> PetriNetODESystem: """ Load a petri net from a file and compile it into a probabilistic program. """ if noise_model == "scaled_beta": return ScaledBetaNoisePetriNetODESystem.from_askenet( - petri_model_or_path, noise_scale=noise_scale, compile_rate_law_p=compile_rate_law_p, compile_observables_p=compile_observables_p, add_uncertainty=add_uncertainty + petri_model_or_path, + noise_scale=noise_scale, + compile_rate_law_p=compile_rate_law_p, + compile_observables_p=compile_observables_p, + add_uncertainty=add_uncertainty, ) elif noise_model == "scaled_normal": return ScaledNormalNoisePetriNetODESystem.from_askenet( - petri_model_or_path, noise_scale=noise_scale, compile_rate_law_p=compile_rate_law_p, compile_observables_p=compile_observables_p, add_uncertainty=add_uncertainty + petri_model_or_path, + noise_scale=noise_scale, + compile_rate_law_p=compile_rate_law_p, + compile_observables_p=compile_observables_p, + add_uncertainty=add_uncertainty, ) else: - raise ValueError(f"Unknown noise model {noise_model}. Please select from either 'scaled_beta' or 'scaled_normal'.") + raise ValueError( + f"Unknown noise model {noise_model}. Please select from either 'scaled_beta' or 'scaled_normal'." + ) @setup_model.register @@ -91,7 +96,7 @@ def setup_petri_model( start_state = { get_name(v): v.data["initial_value"] for v in petri.G.variables.values() } - + # TODO: Figure out how to do this without copying the petri net. start_event = StartEvent(start_time, start_state) new_petri = copy.deepcopy(petri) @@ -114,20 +119,25 @@ def reset_petri_model(petri: PetriNetODESystem) -> PetriNetODESystem: @intervene.register @pyciemss_logging_wrapper def intervene_petri_model( - petri: PetriNetODESystem, interventions: Iterable[Tuple[float, str, float]], jostle_scale: float = 1e-5 + petri: PetriNetODESystem, + interventions: Iterable[Tuple[float, str, float]], + jostle_scale: float = 1e-5, ) -> PetriNetODESystem: """ Intervene on a model. """ # Note: this will have to change if we want to add more sophisticated interventions. interventions = [ - StaticParameterInterventionEvent(timepoint + (0.1+rand.random())*jostle_scale, parameter, value) + StaticParameterInterventionEvent( + timepoint + (0.1 + rand.random()) * jostle_scale, parameter, value + ) for timepoint, parameter, value in interventions ] new_petri = copy.deepcopy(petri) new_petri.load_events(interventions) return new_petri + @calibrate.register @pyciemss_logging_wrapper def calibrate_petri( @@ -145,10 +155,11 @@ def calibrate_petri( """ Use variational inference with a mean-field variational family to infer the parameters of the model. """ - + new_petri = copy.deepcopy(petri) observations = [ - ObservationEvent(timepoint + (0.1+rand.random()) * jostle_scale, observation) for timepoint, observation in data + ObservationEvent(timepoint + (0.1 + rand.random()) * jostle_scale, observation) + for timepoint, observation in data ] for obs in observations: @@ -156,8 +167,10 @@ def calibrate_petri( for v in obs.observation.values(): s += v if not 0 <= v <= petri.total_population: - warnings.warn(f"Observation {obs} is not in the range [0, {petri.total_population}]. This may be an error!") - #assert s <= petri.total_population or torch.isclose(s, petri.total_population) + warnings.warn( + f"Observation {obs} is not in the range [0, {petri.total_population}]. This may be an error!" + ) + # assert s <= petri.total_population or torch.isclose(s, petri.total_population) new_petri.load_events(observations) guide = autoguide(new_petri) @@ -168,7 +181,7 @@ def calibrate_petri( pyro.clear_param_store() for i in range(num_iterations): - progress_hook(i/num_iterations) + progress_hook(i / num_iterations) loss = svi.step(method=method) if verbose: if i % 25 == 0: @@ -178,29 +191,35 @@ def calibrate_petri( @pyciemss_logging_wrapper -def get_posterior_density_mesh_petri(inferred_parameters: PetriInferredParameters, - mesh_params: Optional[dict[str, list[float]]]) -> float: +def get_posterior_density_mesh_petri( + inferred_parameters: PetriInferredParameters, + mesh_params: Optional[dict[str, list[float]]], +) -> float: """ Compute the log posterior density of the inferred parameters at the given parameter values. Args: inferred_parameters: PetriInferredParameters - The inferred parameters from the calibration. mesh_params: dict[str, list] - - Parameter values used to compute a mesh of sample points. + - Parameter values used to compute a mesh of sample points. Keys are parameter names, values are (min, max, steps) parameters passed to linspace. Returns: log_density: float - The log posterior density of the inferred parameters at the given parameter values. - """ + """ spaces = [torch.linspace(*params) for params in mesh_params.values()] - parameter_values = dict(zip(mesh_params.keys(), torch.meshgrid(*spaces, indexing='ij'))) + parameter_values = dict( + zip(mesh_params.keys(), torch.meshgrid(*spaces, indexing="ij")) + ) density = get_posterior_density_petri(inferred_parameters, parameter_values) return parameter_values, density @pyciemss_logging_wrapper -def get_posterior_density_petri(inferred_parameters: PetriInferredParameters, - parameter_values: dict[str, Union[list[float], torch.tensor]]) -> float: +def get_posterior_density_petri( + inferred_parameters: PetriInferredParameters, + parameter_values: dict[str, Union[list[float], torch.tensor]], +) -> float: """ Compute the log posterior density of the inferred parameters at the given parameter values. Args: @@ -213,24 +232,36 @@ def get_posterior_density_petri(inferred_parameters: PetriInferredParameters, - The log posterior density of the inferred parameters at the given parameter values. """ - guides = [guide for guide in inferred_parameters if type(guide) == AutoLowRankMultivariateNormal] + guides = [ + guide + for guide in inferred_parameters + if type(guide) == AutoLowRankMultivariateNormal + ] # By construction there should be only a single AutoLowRankMultivariateNormal guide. The rest should be AutoDeltas. if len(guides) != 1: - raise ValueError(f"Expected a single AutoLowRankMultivariateNormal guide, but found {len(guides)} guides.") + raise ValueError( + f"Expected a single AutoLowRankMultivariateNormal guide, but found {len(guides)} guides." + ) guide = guides[0] # For now we only support density evaluation on the full parameter space. if guide.loc.shape[0] != len(parameter_values): - raise ValueError(f"Expected {guide.loc.shape[0]} parameters, but found {len(parameter_values)} parameters.") + raise ValueError( + f"Expected {guide.loc.shape[0]} parameters, but found {len(parameter_values)} parameters." + ) - parameter_values = {name: torch.as_tensor(value) for name, value in parameter_values.items()} + parameter_values = { + name: torch.as_tensor(value) for name, value in parameter_values.items() + } # Assert that all of the parameters in the `parameter_values` are the same size. parameter_sizes = set([value.size() for value in parameter_values.values()]) if len(parameter_sizes) != 1: - raise ValueError(f"Expected all parameter values to have the same size, but found {len(parameter_sizes)} distinct sizes.") + raise ValueError( + f"Expected all parameter values to have the same size, but found {len(parameter_sizes)} distinct sizes." + ) parameter_size = parameter_sizes.pop() @@ -248,6 +279,7 @@ def get_posterior_density_petri(inferred_parameters: PetriInferredParameters, return torch.exp(log_density).detach() + @sample.register @pyciemss_logging_wrapper def sample_petri( diff --git a/src/pyciemss/ODE/interfaces_bigbox.py b/src/pyciemss/ODE/interfaces_bigbox.py index 03c6db469..706ce33ba 100644 --- a/src/pyciemss/ODE/interfaces_bigbox.py +++ b/src/pyciemss/ODE/interfaces_bigbox.py @@ -1,35 +1,29 @@ -from typing_extensions import deprecated -import pyro -import numpy as np -from typing import Iterable, Optional, Tuple, Union, Callable import copy +from typing import Callable, Iterable, Optional, Tuple, Union -from pyro.infer.autoguide import AutoDelta, AutoLowRankMultivariateNormal, AutoGuideList +import mira +import numpy as np +import pyro +from pyro.infer.autoguide import AutoDelta, AutoGuideList, AutoLowRankMultivariateNormal +from typing_extensions import deprecated -from pyciemss.ODE.base import ( - get_name, -) +import pyciemss.risk.qoi +from pyciemss.integration_utils.custom_decorators import pyciemss_logging_wrapper +from pyciemss.ODE.base import get_name from pyciemss.risk.ouu import computeRisk from pyciemss.risk.risk_measures import alpha_superquantile -import pyciemss.risk.qoi from pyciemss.utils.interface_utils import csv_to_list -import mira - - -from pyciemss.integration_utils.custom_decorators import pyciemss_logging_wrapper - +from ..interfaces import DEFAULT_QUANTILES from .interfaces import ( - load_petri_model, - setup_model, - sample, calibrate, - optimize, intervene, + load_petri_model, + optimize, prepare_interchange_dictionary, + sample, + setup_model, ) -from ..interfaces import DEFAULT_QUANTILES - # TODO: These interfaces should probably be just in terms of JSON-like objects. diff --git a/src/pyciemss/integration_utils/custom_decorators.py b/src/pyciemss/integration_utils/custom_decorators.py index c7b2bd0f2..55b760b96 100644 --- a/src/pyciemss/integration_utils/custom_decorators.py +++ b/src/pyciemss/integration_utils/custom_decorators.py @@ -1,20 +1,19 @@ +import functools import logging import time -import functools -def pyciemss_logging_wrapper( function): + +def pyciemss_logging_wrapper(function): def wrapped(*args, **kwargs): try: start_time = time.perf_counter() result = function(*args, **kwargs) end_time = time.perf_counter() logging.info( - "Elapsed time for %s: %f", - function.__name__, end_time - start_time + "Elapsed time for %s: %f", function.__name__, end_time - start_time ) return result except Exception as e: - log_message = """ ############################### @@ -28,5 +27,6 @@ def wrapped(*args, **kwargs): """ logging.exception(log_message, function.__name__, function.__doc__) raise e + functools.update_wrapper(wrapped, function) return wrapped diff --git a/src/pyciemss/interfaces.py b/src/pyciemss/interfaces.py index c294a1b8d..5f8af875d 100644 --- a/src/pyciemss/interfaces.py +++ b/src/pyciemss/interfaces.py @@ -1,7 +1,7 @@ -import pyro - -from typing import TypeVar, Optional, Iterable, Union import functools +from typing import Iterable, Optional, TypeVar, Union + +import pyro # Declare types # Note: this doesn't really do anything. More of a placeholder for how derived classes should be declared. @@ -173,7 +173,6 @@ def optimize( raise NotImplementedError - @functools.singledispatch def prepare_interchange_dictionary( samples: Simulation,