Skip to content

Commit

Permalink
Fix input lists passed to intervention templates
Browse files Browse the repository at this point in the history
  • Loading branch information
anirban-chaudhuri committed Aug 8, 2024
1 parent b8e9b18 commit 02ed976
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 20 deletions.
8 changes: 4 additions & 4 deletions docs/source/optimize_interface.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1104,11 +1104,11 @@
"intervened_params = [\"beta_c\", \"gamma\", \"gamma_c\"]\n",
"static_parameter_interventions1 = param_value_objective(\n",
" param_name=[intervened_params[0]],\n",
" start_time=torch.tensor(10.0),\n",
" start_time=[torch.tensor(10.0)],\n",
")\n",
"static_parameter_interventions2 = start_time_objective(\n",
" param_name=[intervened_params[1]],\n",
" param_value=torch.tensor([0.4]),\n",
" param_value=[torch.tensor([0.4])],\n",
")\n",
"\n",
"# Combine different intervention templates\n",
Expand Down Expand Up @@ -1270,11 +1270,11 @@
"intervened_params = [\"beta_c\", \"gamma\"]\n",
"static_parameter_interventions1 = param_value_objective(\n",
" param_name=[intervened_params[0]],\n",
" start_time=torch.tensor(10.0),\n",
" start_time=[torch.tensor(10.0)],\n",
")\n",
"static_parameter_interventions2 = start_time_objective(\n",
" param_name=[intervened_params[1]],\n",
" param_value=torch.tensor([0.45]),\n",
" param_value=[torch.tensor([0.45])],\n",
")\n",
"# Combine different intervention templates into a list of Callables\n",
"static_parameter_interventions = intervention_func_combinator(\n",
Expand Down
25 changes: 11 additions & 14 deletions pyciemss/integration_utils/intervention_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,13 @@ def intervention_generator(
for count in range(param_size):
if start_time[count].item() in static_parameter_interventions:
static_parameter_interventions[start_time[count].item()].update(
{param_name[count]: param_value[count](x[count].item()).item()}
{param_name[count]: param_value[count](x[count].item())}
)
else:
static_parameter_interventions.update(
{
start_time[count].item(): {
param_name[count]: param_value[count](
x[count].item()
).item()
param_name[count]: param_value[count](x[count].item())
}
}
)
Expand All @@ -52,7 +50,7 @@ def start_time_objective(
param_value: List[Intervention],
) -> Callable[[torch.Tensor], Dict[float, Dict[str, Intervention]]]:
param_size = len(param_name)
# Note that code below will only work for tensors and not callable functions
# Note: code below will only work for tensors and not callable functions
param_value = [torch.atleast_1d(y) for y in param_value]

def intervention_generator(
Expand All @@ -67,11 +65,11 @@ def intervention_generator(
for count in range(param_size):
if x[count].item() in static_parameter_interventions:
static_parameter_interventions[x[count].item()].update(
{param_name[count]: param_value[count].item()}
{param_name[count]: param_value[count]}
)
else:
static_parameter_interventions.update(
{x[count].item(): {param_name[count]: param_value[count].item()}}
{x[count].item(): {param_name[count]: param_value[count]}}
)
return static_parameter_interventions

Expand Down Expand Up @@ -102,19 +100,15 @@ def intervention_generator(
for count in range(param_size):
if x[count * 2].item() in static_parameter_interventions:
static_parameter_interventions[x[count * 2].item()].update(
{
param_name[count]: param_value[count](
x[count * 2 + 1].item()
).item()
}
{param_name[count]: param_value[count](x[count * 2 + 1].item())}
)
else:
static_parameter_interventions.update(
{
x[count * 2].item(): {
param_name[count]: param_value[count](
x[count * 2 + 1].item()
).item()
)
}
}
)
Expand All @@ -129,7 +123,10 @@ def intervention_func_combinator(
],
intervention_func_lengths: List[int],
) -> Callable[[torch.Tensor], Dict[float, Dict[str, Intervention]]]:
assert len(intervention_funcs) == len(intervention_func_lengths)
assert len(intervention_funcs) == len(intervention_func_lengths), (
f"Size mismatch between number of intervention functions ('{len(intervention_funcs)}')"
f"and number of intervention function lengths ('{len(intervention_func_lengths)}') "
)

total_length = sum(intervention_func_lengths)

Expand Down
4 changes: 2 additions & 2 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,11 @@ def __init__(
intervened_params = ["beta_c", "gamma"]
static_parameter_interventions1 = param_value_objective(
param_name=[intervened_params[0]],
start_time=torch.tensor(10.0),
start_time=[torch.tensor(10.0)],
)
static_parameter_interventions2 = start_time_objective(
param_name=[intervened_params[1]],
param_value=torch.tensor([0.45]),
param_value=[torch.tensor([0.45])],
)
optkwargs_SEIRHD_paramtimeComb_maxQoI = {
"qoi": [lambda x: obs_max_qoi(x, ["I_state"])],
Expand Down

0 comments on commit 02ed976

Please sign in to comment.