From 767f1bdb3927e44c978ced31d83371b9ab9fbb5a Mon Sep 17 00:00:00 2001 From: Anirban Chaudhuri <75496534+anirban-chaudhuri@users.noreply.github.com> Date: Wed, 31 Jul 2024 00:35:37 -0400 Subject: [PATCH] typing --- pyciemss/integration_utils/intervention_builder.py | 4 +++- pyciemss/ouu/ouu.py | 2 +- tests/fixtures.py | 9 +++++---- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/pyciemss/integration_utils/intervention_builder.py b/pyciemss/integration_utils/intervention_builder.py index 94bfba75..a53c35ba 100644 --- a/pyciemss/integration_utils/intervention_builder.py +++ b/pyciemss/integration_utils/intervention_builder.py @@ -130,7 +130,9 @@ def intervention_generator( ) -> Dict[float, Dict[str, Intervention]]: x = torch.atleast_1d(x) assert x.size()[0] == total_length - interventions: List[Dict[float, Dict[str, Intervention]]] = [None] * len(intervention_funcs) + interventions: List[Dict[float, Dict[str, Intervention]]] = [None] * len( + intervention_funcs + ) i = 0 for j, (input_length, intervention_func) in enumerate( zip(intervention_func_lengths, intervention_funcs) diff --git a/pyciemss/ouu/ouu.py b/pyciemss/ouu/ouu.py index 4d528b54..df8f38a8 100644 --- a/pyciemss/ouu/ouu.py +++ b/pyciemss/ouu/ouu.py @@ -1,7 +1,7 @@ import contextlib import warnings from copy import deepcopy -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple import numpy as np import pyro diff --git a/tests/fixtures.py b/tests/fixtures.py index a51373ba..198d161c 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -7,6 +7,7 @@ import torch from pyciemss.integration_utils.intervention_builder import ( + intervention_func_combinator, param_value_objective, start_time_objective, start_time_param_value_objective, @@ -142,10 +143,10 @@ def __init__( optkwargs_SEIRHD_paramtimeComb_maxQoI = { "qoi": lambda x: obs_max_qoi(x, ["I_state"]), "risk_bound": 3e5, - "static_parameter_interventions": lambda x: [ - static_parameter_interventions1(torch.atleast_1d(x[0])), - static_parameter_interventions2(torch.atleast_1d(x[1])), - ], + "static_parameter_interventions": intervention_func_combinator( + [static_parameter_interventions1, static_parameter_interventions2], + [1, 1], + ), "objfun": lambda x: np.abs(0.35 - x[0]) - x[1], "initial_guess_interventions": [0.35, 5.0], "bounds_interventions": [[0.1, 1.0], [0.5, 90.0]],