Skip to content

Commit

Permalink
typing
Browse files Browse the repository at this point in the history
  • Loading branch information
anirban-chaudhuri committed Jul 31, 2024
1 parent c9f3239 commit 767f1bd
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
4 changes: 3 additions & 1 deletion pyciemss/integration_utils/intervention_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ def intervention_generator(
) -> Dict[float, Dict[str, Intervention]]:
x = torch.atleast_1d(x)
assert x.size()[0] == total_length
interventions: List[Dict[float, Dict[str, Intervention]]] = [None] * len(intervention_funcs)
interventions: List[Dict[float, Dict[str, Intervention]]] = [None] * len(
intervention_funcs
)
i = 0
for j, (input_length, intervention_func) in enumerate(
zip(intervention_func_lengths, intervention_funcs)
Expand Down
2 changes: 1 addition & 1 deletion pyciemss/ouu/ouu.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import contextlib
import warnings
from copy import deepcopy
from typing import Any, Callable, Dict, List, Tuple, Union
from typing import Any, Callable, Dict, List, Tuple

import numpy as np
import pyro
Expand Down
9 changes: 5 additions & 4 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch

from pyciemss.integration_utils.intervention_builder import (
intervention_func_combinator,
param_value_objective,
start_time_objective,
start_time_param_value_objective,
Expand Down Expand Up @@ -142,10 +143,10 @@ def __init__(
optkwargs_SEIRHD_paramtimeComb_maxQoI = {
"qoi": lambda x: obs_max_qoi(x, ["I_state"]),
"risk_bound": 3e5,
"static_parameter_interventions": lambda x: [
static_parameter_interventions1(torch.atleast_1d(x[0])),
static_parameter_interventions2(torch.atleast_1d(x[1])),
],
"static_parameter_interventions": intervention_func_combinator(
[static_parameter_interventions1, static_parameter_interventions2],
[1, 1],
),
"objfun": lambda x: np.abs(0.35 - x[0]) - x[1],
"initial_guess_interventions": [0.35, 5.0],
"bounds_interventions": [[0.1, 1.0], [0.5, 90.0]],
Expand Down

0 comments on commit 767f1bd

Please sign in to comment.