KeyError for dynamic parameter intervention #525

Closed
sabinala opened this issue Mar 8, 2024 · 0 comments

I'm getting the following error when trying to sample with a `dynamic_parameter_interventions` argument:

ERROR:root:
                ###############################

                There was an exception in pyciemss

                Error occurred in function: sample

                Function docs : 
    Load a model from a file, compile it into a probabilistic program, and sample from it.

    Args:
        model_path_or_json: Union[str, Dict]
            - A path to an AMR model file or JSON containing a model in AMR form.
        end_time: float
            - The end time of the sampled simulation.
        logging_step_size: float
            - The step size to use for logging the trajectory.
        num_samples: int
            - The number of samples to draw from the model.
        solver_method: str
            - The method to use for solving the ODE. See torchdiffeq's `odeint` method for more details.
            - If performance is incredibly slow, we suggest using `euler` to debug.
              If using `euler` results in faster simulation, the issue is likely that the model is stiff.
        solver_options: Dict[str, Any]
            - Options to pass to the solver. See torchdiffeq's `odeint` method for more details.
        start_time: float
            - The start time of the model. This is used to align the `start_state` from the
              AMR model with the simulation timepoints.
            - By default we set the `start_time` to be 0.
        inferred_parameters: Optional[pyro.nn.PyroModule]
            - A Pyro module that contains the inferred parameters of the model.
              This is typically the result of `calibrate`.
            - If not provided, we will use the default values from the AMR model.
        static_state_interventions: Dict[float, Dict[str, Intervention]]
            - A dictionary of static interventions to apply to the model.
            - Each key is the time at which the intervention is applied.
            - Each value is a dictionary of the form {state_variable_name: intervention_assignment}.
            - Note that the `intervention_assignment` can be any type supported by
              :func:`~chirho.interventional.ops.intervene`, including functions.
        static_parameter_interventions: Dict[float, Dict[str, Intervention]]
            - A dictionary of static interventions to apply to the model.
            - Each key is the time at which the intervention is applied.
            - Each value is a dictionary of the form {parameter_name: intervention_assignment}.
            - Note that the `intervention_assignment` can be any type supported by
              :func:`~chirho.interventional.ops.intervene`, including functions.
        dynamic_state_interventions: Dict[
                                        Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
                                        Dict[str, Intervention]
                                        ]
            - A dictionary of dynamic interventions to apply to the model.
            - Each key is a function that takes in the current state of the model and returns a tensor.
              When this function crosses 0, the dynamic intervention is applied.
            - Each value is a dictionary of the form {state_variable_name: intervention_assignment}.
            - Note that the `intervention_assignment` can be any type supported by
              :func:`~chirho.interventional.ops.intervene`, including functions.
        dynamic_parameter_interventions: Dict[
                                            Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
                                            Dict[str, Intervention]
                                            ]
            - A dictionary of dynamic interventions to apply to the model.
            - Each key is a function that takes in the current state of the model and returns a tensor.
              When this function crosses 0, the dynamic intervention is applied.
            - Each value is a dictionary of the form {parameter_name: intervention_assignment}.
            - Note that the `intervention_assignment` can be any type supported by
              :func:`~chirho.interventional.ops.intervene`, including functions.
        alpha: float
            - Risk level for alpha-superquantile outputs in the results dictionary.

    Returns:
        result: Dict[str, torch.Tensor]
            - Dictionary of outputs with following attributes:
                - data: The samples from the model as a pandas DataFrame.
                - unprocessed_result: Dictionary of outputs from the model.
                    - Each key is the name of a parameter or state variable in the model.
                    - Each value is a tensor of shape (num_samples, num_timepoints) for state variables
                    and (num_samples,) for parameters.
                - quantiles: The quantiles for ensemble score calculation as a pandas DataFrame.
                - risk: Dictionary with each key as the name of a state with
                a dictionary of risk estimates for each state at the final timepoint.
                    - risk: alpha-superquantile risk estimate
                    Superquantiles can be intuitively thought of as a tail expectation, or an average
                    over a portion of worst-case outcomes. Given a distribution of a
                    quantity of interest (QoI), the superquantile at level \alpha \in [0, 1] is
                    the expected value of the largest 100(1 - \alpha)% realizations of the QoI.
                    - qoi: Samples of quantity of interest (value of the state at the final timepoint)
                - schema: Visualization. (If visual_options is truthy)
    

                ################################
            
Traceback (most recent call last):
  File "/Users/altu809/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py", line 10, in wrapped
    result = function(*args, **kwargs)
  File "/Users/altu809/Projects/pyciemss/pyciemss/interfaces.py", line 332, in sample
    samples = pyro.infer.Predictive(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py", line 273, in forward
    return _predictive(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py", line 127, in _predictive
    return _predictive_sequential(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py", line 55, in _predictive_sequential
    {site: trace.nodes[site]["value"] for site in return_site_shapes}
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py", line 55, in <dictcomp>
    {site: trace.nodes[site]["value"] for site in return_site_shapes}
KeyError: 'parameter_intervention_value_p_cbeta_0'
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[15], line 10
      7 infection_threshold = make_var_threshold("I", torch.tensor(400.0))
      8 dynamic_parameter_interventions1 = {infection_threshold: {"p_cbeta": torch.tensor(0.3)}}
---> 10 result = pyciemss.sample(model3, end_time, logging_step_size, num_samples, start_time=start_time, 
     11                          dynamic_parameter_interventions=dynamic_parameter_interventions1, 
     12                          solver_method="dopri5")
     13 display(result["data"].head())
     15 # Plot the result

File ~/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py:29, in pyciemss_logging_wrapper.<locals>.wrapped(*args, **kwargs)
     17 log_message = """
     18     ###############################
     19 
   (...)
     26     ################################
     27 """
     28 logging.exception(log_message, function.__name__, function.__doc__)
---> 29 raise e

File ~/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py:10, in pyciemss_logging_wrapper.<locals>.wrapped(*args, **kwargs)
      8 try:
      9     start_time = time.perf_counter()
---> 10     result = function(*args, **kwargs)
     11     end_time = time.perf_counter()
     12     logging.info(
     13         "Elapsed time for %s: %f", function.__name__, end_time - start_time
     14     )

File ~/Projects/pyciemss/pyciemss/interfaces.py:332, in sample(model_path_or_json, end_time, logging_step_size, num_samples, noise_model, noise_model_kwargs, solver_method, solver_options, start_time, time_unit, inferred_parameters, static_state_interventions, static_parameter_interventions, dynamic_state_interventions, dynamic_parameter_interventions, alpha)
    320         compiled_noise_model(full_trajectory)
    322 parallel = (
    323     False
    324     if len(
   (...)
    329     else True
    330 )
--> 332 samples = pyro.infer.Predictive(
    333     wrapped_model,
    334     guide=inferred_parameters,
    335     num_samples=num_samples,
    336     parallel=parallel,
    337 )()
    339 risk_results = {}
    340 for k, vals in samples.items():

File ~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:273, in Predictive.forward(self, *args, **kwargs)
    263     return_sites = None if not return_sites else return_sites
    264     posterior_samples = _predictive(
    265         self.guide,
    266         posterior_samples,
   (...)
    271         model_kwargs=kwargs,
    272     )
--> 273 return _predictive(
    274     self.model,
    275     posterior_samples,
    276     self.num_samples,
    277     return_sites=return_sites,
    278     parallel=self.parallel,
    279     model_args=args,
    280     model_kwargs=kwargs,
    281 )

File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:127, in _predictive(model, posterior_samples, num_samples, return_sites, return_trace, parallel, model_args, model_kwargs)
    124     return_site_shapes["_RETURN"] = shape
    126 if not parallel:
--> 127     return _predictive_sequential(
    128         model,
    129         posterior_samples,
    130         model_args,
    131         model_kwargs,
    132         num_samples,
    133         return_site_shapes,
    134         return_trace=False,
    135     )
    137 trace = poutine.trace(
    138     poutine.condition(vectorize(model), reshaped_samples)
    139 ).get_trace(*model_args, **model_kwargs)
    140 predictions = {}

File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:55, in _predictive_sequential(model, posterior_samples, model_args, model_kwargs, num_samples, return_site_shapes, return_trace)
     52         collected.append(trace)
     53     else:
     54         collected.append(
---> 55             {site: trace.nodes[site]["value"] for site in return_site_shapes}
     56         )
     58 if return_trace:
     59     return collected

File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:55, in <dictcomp>(.0)
     52         collected.append(trace)
     53     else:
     54         collected.append(
---> 55             {site: trace.nodes[site]["value"] for site in return_site_shapes}
     56         )
     58 if return_trace:
     59     return collected

KeyError: 'parameter_intervention_value_p_cbeta_0'
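The failing cell references a `make_var_threshold` helper that is not defined in the issue. A minimal sketch of what such a helper presumably looks like (the name and signature are taken from the cell above; the plain-float state is an assumption to keep the sketch self-contained — the docstring says dynamic interventions are keyed on event functions of `(time, state)` whose zero-crossing triggers the intervention):

```python
# Hypothetical sketch of the make_var_threshold helper from the failing cell.
# Dynamic interventions are keyed on event functions f(time, state); the
# intervention fires when the returned value crosses zero (per the docstring).
def make_var_threshold(var, threshold):
    def event_fn(time, state):
        # Negative below the threshold, positive above it, so the
        # zero-crossing marks the moment state[var] reaches threshold.
        return state[var] - threshold
    return event_fn

# With I = 450 and a threshold of 400 the event function is already positive:
infection_threshold = make_var_threshold("I", 400.0)
print(infection_threshold(0.0, {"I": 450.0}))  # 50.0
```

(In the real notebook the state values are `torch.Tensor`s rather than floats.) The KeyError itself comes from Pyro's `Predictive`: the expected return site `parameter_intervention_value_p_cbeta_0` is missing from the trace, which suggests the site is only created when the event actually fires during the run.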
@sabinala sabinala closed this as completed Mar 8, 2024