Skip to content

Commit

Permalink
adding template for start time and parameter value intervention
Browse files Browse the repository at this point in the history
  • Loading branch information
anirban-chaudhuri committed Jun 26, 2024
1 parent 422232c commit 704d835
Show file tree
Hide file tree
Showing 2 changed files with 498 additions and 65 deletions.
512 changes: 452 additions & 60 deletions docs/source/interfaces.ipynb

Large diffs are not rendered by default.

51 changes: 46 additions & 5 deletions pyciemss/integration_utils/intervention_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ def param_value_objective(
start_time: List[torch.Tensor],
param_value: List[Intervention] = [None],
) -> Callable[[torch.Tensor], Dict[float, Dict[str, Intervention]]]:
if len(param_value) < len(param_name) and param_value[0] is None:
param_size = len(param_name)
if len(param_value) < param_size and param_value[0] is None:
param_value = [None for _ in param_name]
for count in range(len(param_name)):
for count in range(param_size):
if param_value[count] is None:
if not callable(param_value[count]):
param_value[count] = lambda y: torch.tensor(y)
Expand All @@ -20,7 +21,7 @@ def intervention_generator(
x: torch.Tensor,
) -> Dict[float, Dict[str, Intervention]]:
static_parameter_interventions: Dict[float, Dict[str, Intervention]] = {}
for count in range(len(param_name)):
for count in range(param_size):
if start_time[count].item() in static_parameter_interventions:
static_parameter_interventions[start_time[count].item()].update(
{param_name[count]: param_value[count](x[count].item())}
Expand All @@ -39,13 +40,16 @@ def intervention_generator(


def start_time_objective(
param_name: List[str], param_value: List[Intervention]
param_name: List[str],
param_value: List[Intervention],
) -> Callable[[torch.Tensor], Dict[float, Dict[str, Intervention]]]:
param_size = len(param_name)

def intervention_generator(
x: torch.Tensor,
) -> Dict[float, Dict[str, Intervention]]:
static_parameter_interventions: Dict[float, Dict[str, Intervention]] = {}
for count in range(len(param_name)):
for count in range(param_size):
if x[count].item() in static_parameter_interventions:
static_parameter_interventions[x[count].item()].update(
{param_name[count]: param_value[count]}
Expand All @@ -59,6 +63,43 @@ def intervention_generator(
return intervention_generator


def start_time_param_value_objective(
param_name: List[str],
param_value: List[Intervention] = [None],
) -> Callable[[torch.Tensor], Dict[float, Dict[str, Intervention]]]:
param_size = len(param_name)
if len(param_value) < param_size and param_value[0] is None:
param_value = [None for _ in param_name]
for count in range(param_size):
if param_value[count] is None:
if not callable(param_value[count]):
param_value[count] = lambda y: torch.tensor(y)

def intervention_generator(
x: torch.Tensor,
) -> Dict[float, Dict[str, Intervention]]:
assert x.size()[0] == param_size*2
static_parameter_interventions: Dict[float, Dict[str, Intervention]] = {}
for count in range(param_size):
if x[count * 2].item() in static_parameter_interventions:
static_parameter_interventions[x[count * 2].item()].update(
{param_name[count]: param_value[count](x[count * 2 + 1].item())}
)
else:
static_parameter_interventions.update(
{
x[count * 2].item(): {
param_name[count]: param_value[count](
x[count * 2 + 1].item()
)
}
}
)
return static_parameter_interventions

return intervention_generator


def combine_static_parameter_interventions(
interventions: List[Dict[torch.Tensor, Dict[str, Intervention]]]
) -> Dict[torch.Tensor, Dict[str, Intervention]]:
Expand Down

0 comments on commit 704d835

Please sign in to comment.