Getting the following error when trying to `sample` with a `dynamic_parameter_intervention`:
```
ERROR:root:
###############################

There was an exception in pyciemss

Error occured in function: sample

Function docs : Load a model from a file, compile it into a probabilistic program, and sample from it.

    Args:
        model_path_or_json: Union[str, Dict]
            - A path to a AMR model file or JSON containing a model in AMR form.
        end_time: float
            - The end time of the sampled simulation.
        logging_step_size: float
            - The step size to use for logging the trajectory.
        num_samples: int
            - The number of samples to draw from the model.
        solver_method: str
            - The method to use for solving the ODE. See torchdiffeq's `odeint` method for more details.
            - If performance is incredibly slow, we suggest using `euler` to debug.
              If using `euler` results in faster simulation, the issue is likely that the model is stiff.
        solver_options: Dict[str, Any]
            - Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
        start_time: float
            - The start time of the model. This is used to align the `start_state` from the
              AMR model with the simulation timepoints.
            - By default we set the `start_time` to be 0.
        inferred_parameters: Optional[pyro.nn.PyroModule]
            - A Pyro module that contains the inferred parameters of the model.
              This is typically the result of `calibrate`.
            - If not provided, we will use the default values from the AMR model.
        static_state_interventions: Dict[float, Dict[str, Intervention]]
            - A dictionary of static interventions to apply to the model.
            - Each key is the time at which the intervention is applied.
            - Each value is a dictionary of the form {state_variable_name: intervention_assignment}.
            - Note that the `intervention_assignment` can be any type supported by
              :func:`~chirho.interventional.ops.intervene`, including functions.
        static_parameter_interventions: Dict[float, Dict[str, Intervention]]
            - A dictionary of static interventions to apply to the model.
            - Each key is the time at which the intervention is applied.
            - Each value is a dictionary of the form {parameter_name: intervention_assignment}.
            - Note that the `intervention_assignment` can be any type supported by
              :func:`~chirho.interventional.ops.intervene`, including functions.
        dynamic_state_interventions: Dict[
                Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
                Dict[str, Intervention]
            ]
            - A dictionary of dynamic interventions to apply to the model.
            - Each key is a function that takes in the current state of the model and returns a tensor.
              When this function crosses 0, the dynamic intervention is applied.
            - Each value is a dictionary of the form {state_variable_name: intervention_assignment}.
            - Note that the `intervention_assignment` can be any type supported by
              :func:`~chirho.interventional.ops.intervene`, including functions.
        dynamic_parameter_interventions: Dict[
                Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
                Dict[str, Intervention]
            ]
            - A dictionary of dynamic interventions to apply to the model.
            - Each key is a function that takes in the current state of the model and returns a tensor.
              When this function crosses 0, the dynamic intervention is applied.
            - Each value is a dictionary of the form {parameter_name: intervention_assignment}.
            - Note that the `intervention_assignment` can be any type supported by
              :func:`~chirho.interventional.ops.intervene`, including functions.
        alpha: float
            - Risk level for alpha-superquantile outputs in the results dictionary.

    Returns:
        result: Dict[str, torch.Tensor]
            - Dictionary of outputs with following attributes:
                - data: The samples from the model as a pandas DataFrame.
                - unprocessed_result: Dictionary of outputs from the model.
                    - Each key is the name of a parameter or state variable in the model.
                    - Each value is a tensor of shape (num_samples, num_timepoints) for state variables
                      and (num_samples,) for parameters.
                - quantiles: The quantiles for ensemble score calculation as a pandas DataFrames.
                - risk: Dictionary with each key as the name of a state with
                  a dictionary of risk estimates for each state at the final timepoint.
                    - risk: alpha-superquantile risk estimate
                      Superquantiles can be intuitively thought of as a tail expectation, or an average
                      over a portion of worst-case outcomes. Given a distribution of a
                      quantity of interest (QoI), the superquantile at level \alpha\in[0, 1] is
                      the expected value of the largest 100(1 -\alpha)% realizations of the QoI.
                    - qoi: Samples of quantity of interest (value of the state at the final timepoint)
                - schema: Visualization. (If visual_options is truthy)

################################

Traceback (most recent call last):
  File "/Users/altu809/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py", line 10, in wrapped
    result = function(*args, **kwargs)
  File "/Users/altu809/Projects/pyciemss/pyciemss/interfaces.py", line 332, in sample
    samples = pyro.infer.Predictive(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py", line 273, in forward
    return _predictive(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py", line 127, in _predictive
    return _predictive_sequential(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py", line 55, in _predictive_sequential
    {site: trace.nodes[site]["value"] for site in return_site_shapes}
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py", line 55, in <dictcomp>
    {site: trace.nodes[site]["value"] for site in return_site_shapes}
KeyError: 'parameter_intervention_value_p_cbeta_0'
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[15], line 10
      7 infection_threshold = make_var_threshold("I", torch.tensor(400.0))
      8 dynamic_parameter_interventions1 = {infection_threshold: {"p_cbeta": torch.tensor(0.3)}}
---> 10 result = pyciemss.sample(model3, end_time, logging_step_size, num_samples, start_time=start_time,
     11                          dynamic_parameter_interventions=dynamic_parameter_interventions1,
     12                          solver_method="dopri5")
     13 display(result["data"].head())
     15 # Plot the result

File ~/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py:29, in pyciemss_logging_wrapper.<locals>.wrapped(*args, **kwargs)
     17     log_message = """
     18         ###############################
     19 
   (...)
     26         ################################
     27     """
     28     logging.exception(log_message, function.__name__, function.__doc__)
---> 29     raise e

File ~/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py:10, in pyciemss_logging_wrapper.<locals>.wrapped(*args, **kwargs)
      8 try:
      9     start_time = time.perf_counter()
---> 10     result = function(*args, **kwargs)
     11     end_time = time.perf_counter()
     12     logging.info(
     13         "Elapsed time for %s: %f", function.__name__, end_time - start_time
     14     )

File ~/Projects/pyciemss/pyciemss/interfaces.py:332, in sample(model_path_or_json, end_time, logging_step_size, num_samples, noise_model, noise_model_kwargs, solver_method, solver_options, start_time, time_unit, inferred_parameters, static_state_interventions, static_parameter_interventions, dynamic_state_interventions, dynamic_parameter_interventions, alpha)
    320     compiled_noise_model(full_trajectory)
    322 parallel = (
    323     False
    324     if len(
   (...)
    329     else True
    330 )
--> 332 samples = pyro.infer.Predictive(
    333     wrapped_model,
    334     guide=inferred_parameters,
    335     num_samples=num_samples,
    336     parallel=parallel,
    337 )()
    339 risk_results = {}
    340 for k, vals in samples.items():

File ~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:273, in Predictive.forward(self, *args, **kwargs)
    263     return_sites = None if not return_sites else return_sites
    264     posterior_samples = _predictive(
    265         self.guide,
    266         posterior_samples,
   (...)
    271         model_kwargs=kwargs,
    272     )
--> 273 return _predictive(
    274     self.model,
    275     posterior_samples,
    276     self.num_samples,
    277     return_sites=return_sites,
    278     parallel=self.parallel,
    279     model_args=args,
    280     model_kwargs=kwargs,
    281 )

File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:127, in _predictive(model, posterior_samples, num_samples, return_sites, return_trace, parallel, model_args, model_kwargs)
    124     return_site_shapes["_RETURN"] = shape
    126 if not parallel:
--> 127     return _predictive_sequential(
    128         model,
    129         posterior_samples,
    130         model_args,
    131         model_kwargs,
    132         num_samples,
    133         return_site_shapes,
    134         return_trace=False,
    135     )
    137 trace = poutine.trace(
    138     poutine.condition(vectorize(model), reshaped_samples)
    139 ).get_trace(*model_args, **model_kwargs)
    140 predictions = {}

File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:55, in _predictive_sequential(model, posterior_samples, model_args, model_kwargs, num_samples, return_site_shapes, return_trace)
     52         collected.append(trace)
     53     else:
     54         collected.append(
---> 55             {site: trace.nodes[site]["value"] for site in return_site_shapes}
     56         )
     58 if return_trace:
     59     return collected

File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:55, in <dictcomp>(.0)
     52         collected.append(trace)
     53     else:
     54         collected.append(
---> 55             {site: trace.nodes[site]["value"] for site in return_site_shapes}
     56         )
     58 if return_trace:
     59     return collected

KeyError: 'parameter_intervention_value_p_cbeta_0'
```
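A possible explanation, not confirmed here: per the docstring, a dynamic intervention is applied only when its event function crosses 0. A sampled trajectory in which `I` never reaches 400 would therefore never create the `parameter_intervention_value_p_cbeta_0` site. The sequential `Predictive` path then indexes every name in `return_site_shapes` on each trace (line 55 in the traceback), which would raise exactly this `KeyError`. The plain-Python sketch below mimics that pattern. `make_var_threshold`, `first_crossing` and `fake_trace` are hypothetical stand-ins (plain floats instead of `torch.Tensor`), and the sketch assumes the site list is taken from a trace in which the event did fire.

```python
from typing import Callable, Dict, List, Optional

EventFn = Callable[[float, Dict[str, float]], float]


# Hypothetical stand-in for the notebook's make_var_threshold helper:
# the event function is zero exactly when state[var] equals the threshold.
def make_var_threshold(var: str, threshold: float) -> EventFn:
    def var_threshold(time: float, state: Dict[str, float]) -> float:
        return state[var] - threshold
    return var_threshold


def first_crossing(event_fn: EventFn, times: List[float],
                   states: List[Dict[str, float]]) -> Optional[float]:
    """Return the first logged time at which event_fn changes sign, or None."""
    prev = event_fn(times[0], states[0])
    for t, s in zip(times[1:], states[1:]):
        cur = event_fn(t, s)
        if (prev < 0) != (cur < 0):  # sign change => zero crossing
            return t
        prev = cur
    return None


def fake_trace(trajectory_I: List[float]) -> Dict[str, float]:
    """Mimic one sample's trace: the intervention site exists only if the event fired."""
    times = [float(i) for i in range(len(trajectory_I))]
    states = [{"I": v} for v in trajectory_I]
    trace = {"I_final": trajectory_I[-1]}
    if first_crossing(make_var_threshold("I", 400.0), times, states) is not None:
        trace["parameter_intervention_value_p_cbeta_0"] = 0.3
    return trace


# Sample 1 crosses I = 400; sample 2 never does.
traces = [fake_trace([100.0, 300.0, 500.0]), fake_trace([100.0, 200.0, 250.0])]

# As in the sequential Predictive loop, take the site names from one trace
# and index them on every trace.
return_sites = list(traces[0])
try:
    collected = [{site: tr[site] for site in return_sites} for tr in traces]
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'parameter_intervention_value_p_cbeta_0'
```

If this is the cause, a run in which every sample crosses the threshold should not hit the error; that is a check to try, not a verified workaround.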