Skip to content

Commit

Permalink
Adding support for ensemble output format (#570)
Browse files Browse the repository at this point in the history
* Add ensemble output formatting

* Adding support for cumulative states

* Updated CDC reformatting

* using deepcopy to avoid overwriting results

* linting

* Lint

* fix time unit issue in cdc formatting

* updated defaults

* Lint

* Update result_processing.py

* update for time_unit

* Update result_processing.py

* Update result_processing.py

* Update result_processing.py

* lint

* update tests to match output from `convert_to_output_format`

* Update result_processing.py

* Update result_processing.py

* Update result_processing.py

* Update result_processing.py

* Update result_processing.py

* Update interfaces.py

* update alpha_qs type to Sized

* updating alpha_qs type to List

* Update result_processing.py

* Update result_processing.py

* lint

* fix warning

* Update interfaces.ipynb

* Dropping all other states/observables except the ones in `solution_string_mapping`

* Updating logging_times to include start_time and end_time in ensemble_sample

* Updating how to set train_end_point

* lint

* Update interfaces.ipynb

* Update interfaces.ipynb

* fixed failing tests (#571)

* Update interfaces.ipynb

* Fixing tests with expanded timespan

* Lint

* resolved viz test time misalignment (#573)

* Removing commented out code; update documentation for ensemble_sample

* replaced DEFAULT_ALPHA_QS with import

* lint

---------

Co-authored-by: Sam Witty <[email protected]>
  • Loading branch information
anirban-chaudhuri and SamWitty authored Apr 25, 2024
1 parent 4c65d26 commit a18eb59
Show file tree
Hide file tree
Showing 7 changed files with 830 additions and 175 deletions.
704 changes: 560 additions & 144 deletions docs/source/interfaces.ipynb

Large diffs are not rendered by default.

238 changes: 228 additions & 10 deletions pyciemss/integration_utils/result_processing.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,63 @@
from typing import Any, Dict, Iterable, Mapping, Optional, Union
import warnings
from copy import deepcopy
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch

from pyciemss.visuals import plots

DEFAULT_ALPHA_QS = [
0.01,
0.025,
0.05,
0.1,
0.15,
0.2,
0.25,
0.3,
0.35,
0.4,
0.45,
0.5,
0.55,
0.6,
0.65,
0.7,
0.75,
0.8,
0.85,
0.9,
0.95,
0.975,
0.99,
]


def prepare_interchange_dictionary(
samples: Dict[str, torch.Tensor],
time_unit: Optional[str] = None,
timepoints: Optional[Iterable[float]] = None,
timepoints: Optional[torch.Tensor] = None,
visual_options: Union[None, bool, Dict[str, Any]] = None,
ensemble_quantiles: bool = False,
alpha_qs: Optional[List[float]] = DEFAULT_ALPHA_QS,
stacking_order: str = "timepoints",
) -> Dict[str, Any]:
samples = {k: (v.squeeze() if len(v.shape) > 2 else v) for k, v in samples.items()}

processed_samples = convert_to_output_format(
samples, time_unit=time_unit, timepoints=timepoints
processed_samples, quantile_results = convert_to_output_format(
samples,
time_unit=time_unit,
timepoints=timepoints,
ensemble_quantiles=ensemble_quantiles,
alpha_qs=alpha_qs,
stacking_order=stacking_order,
)

result = {"data": processed_samples, "unprocessed_result": samples}
if ensemble_quantiles:
result["ensemble_quantiles"] = quantile_results

if visual_options:
visual_options = {} if visual_options is True else visual_options
Expand All @@ -33,8 +71,11 @@ def convert_to_output_format(
samples: Dict[str, torch.Tensor],
*,
time_unit: Optional[str] = None,
timepoints: Optional[Iterable[float]] = None,
) -> pd.DataFrame:
timepoints: Optional[torch.Tensor] = None,
ensemble_quantiles: bool = False,
alpha_qs: Optional[List[float]] = None,
stacking_order: str = "timepoints",
) -> Tuple[pd.DataFrame, Union[pd.DataFrame, None]]:
"""
Convert the samples from the Pyro model to a DataFrame in the TA4 requested format.
"""
Expand All @@ -55,7 +96,7 @@ def convert_to_output_format(
if time_unit is not None and timepoints is None:
raise ValueError("`timepoints` must be supplied when a `time_unit` is supplied")

pyciemss_results: Dict[str, Dict[str, torch.Tensor]] = {
pyciemss_results: Dict[str, Dict[str, np.ndarray]] = {
"parameters": {},
"states": {},
}
Expand Down Expand Up @@ -86,9 +127,10 @@ def convert_to_output_format(
}

if timepoints is not None:
timepoints = [*timepoints]
label = "timepoint_unknown" if time_unit is None else f"timepoint_{time_unit}"
output[label] = np.array(float(timepoints[v]) for v in output["timepoint_id"])
output[label] = np.array(
float(timepoints[v].item()) for v in output["timepoint_id"]
)

# Parameters
output = {
Expand Down Expand Up @@ -116,7 +158,183 @@ def convert_to_output_format(
):
result = set_intervention_values(result, name, values, intervention_times)

return result
if ensemble_quantiles:
result_quantiles = make_quantiles(
pyciemss_results,
alpha_qs=alpha_qs,
time_unit=time_unit,
timepoints=timepoints,
stacking_order=stacking_order,
)
else:
result_quantiles = None

return result, result_quantiles


def make_quantiles(
pyciemss_results: Dict[str, Dict[str, np.ndarray]],
*,
alpha_qs: Optional[List[float]] = None,
time_unit: Optional[str] = None,
timepoints: Optional[torch.Tensor] = None,
stacking_order: str = "timepoints",
) -> Union[pd.DataFrame, None]:
"""Make quantiles for each timepoint"""
_, num_timepoints = next(iter(pyciemss_results["states"].values())).shape
key_list = ["timepoint_id", "output", "type", "quantile", "value"]
q: Dict[str, List] = {k: [] for k in key_list}
if alpha_qs is not None:
num_quantiles = len(alpha_qs)

# Solution (state variables)
for k, v in pyciemss_results["states"].items():
q_vals = np.quantile(v, alpha_qs, axis=0)
k = k.replace("_sol", "")
if stacking_order == "timepoints":
# Keeping timepoints together
q["timepoint_id"].extend(
list(np.repeat(np.array(range(num_timepoints)), num_quantiles))
)
q["output"].extend([k] * num_timepoints * num_quantiles)
q["type"].extend(["quantile"] * num_timepoints * num_quantiles)
q["quantile"].extend(list(np.tile(alpha_qs, num_timepoints)))
q["value"].extend(
list(
np.squeeze(
q_vals.T.reshape((num_timepoints * num_quantiles, 1))
)
)
)
elif stacking_order == "quantiles":
# Keeping quantiles together
q["timepoint_id"].extend(
list(np.tile(np.array(range(num_timepoints)), num_quantiles))
)
q["output"].extend([k] * num_timepoints * num_quantiles)
q["type"].extend(["quantile"] * num_timepoints * num_quantiles)
q["quantile"].extend(list(np.repeat(alpha_qs, num_timepoints)))
q["value"].extend(
list(
np.squeeze(q_vals.reshape((num_timepoints * num_quantiles, 1)))
)
)
else:
raise Exception("Incorrect input for stacking_order.")

result_quantiles = pd.DataFrame(q)
if timepoints is not None:
all_timepoints = result_quantiles["timepoint_id"].map(
lambda v: timepoints[v].item()
)
result_quantiles = result_quantiles.assign(
**{f"number_{time_unit}": all_timepoints}
)
result_quantiles = result_quantiles[
[
"timepoint_id",
f"number_{time_unit}",
"output",
"type",
"quantile",
"value",
]
]
else:
result_quantiles = None
return result_quantiles


def cdc_format(
q_ensemble_input: pd.DataFrame,
solution_string_mapping: Dict[str, str],
*,
time_unit: Optional[str] = None,
forecast_start_date: Optional[str] = None,
location: Optional[str] = None,
drop_column_names: List[str] = [
"timepoint_id",
"output",
],
train_end_point: Optional[float] = None,
) -> pd.DataFrame:
"""
Reformat the quantiles pandas dataframe file to CDC ensemble forecast format
Note that solution_string_mapping maps name of states/observables in the dictionary key to the dictionary value
and also drops any states/observables not available in the dictionary keys.
forecast_start_date is the date of last observed data.
"""
q_ensemble_data = deepcopy(q_ensemble_input)
if time_unit != "days" or time_unit is None:
warnings.warn(
"cdc_format only works for time_unit=days"
"time_unit will default to days and overwrite previous time_unit."
)
q_ensemble_data.rename(columns={"number_None": "number_days"}, inplace=True)
if "number_days" not in q_ensemble_data:
raise ValueError("time_unit can only support days")
time_unit = "days"

if train_end_point is None:
q_ensemble_data["Forecast_Backcast"] = "Forecast"
number_data_days = 0.0
else:
q_ensemble_data["Forecast_Backcast"] = np.where(
q_ensemble_data[f"number_{time_unit}"] > train_end_point,
"Forecast",
"Backcast",
)
# Number of days for which data is available
number_data_days = max(
q_ensemble_data[
q_ensemble_data["Forecast_Backcast"].str.contains("Backcast")
][f"number_{time_unit}"]
)
drop_column_names.extend(["Forecast_Backcast"])
# Subtracting number of backast days from number_days
q_ensemble_data[f"number_{time_unit}"] = (
q_ensemble_data[f"number_{time_unit}"] - number_data_days
)
# Drop rows that are backcasting
q_ensemble_data = q_ensemble_data[
~q_ensemble_data["Forecast_Backcast"].str.contains("Backcast")
]
# Changing name of state according to user provided strings
if solution_string_mapping:
# Drop rows that are not present in the solution_string_mapping keys
q_ensemble_data = q_ensemble_data[
q_ensemble_data["output"].str.contains(
"|".join(solution_string_mapping.keys())
)
]
for k, v in solution_string_mapping.items():
q_ensemble_data["output"] = q_ensemble_data["output"].replace(k, v)

# Creating target column
q_ensemble_data["target"] = (
q_ensemble_data[f"number_{time_unit}"].astype("string")
+ " days ahead "
# + q_ensemble_data["inc_cum"]
+ " "
+ q_ensemble_data["output"]
)

# Add dates
if forecast_start_date:
q_ensemble_data["forecast_date"] = pd.to_datetime(
forecast_start_date, format="%Y-%m-%d", errors="ignore"
)
q_ensemble_data["target_end_date"] = q_ensemble_data["forecast_date"].combine(
q_ensemble_data[f"number_{time_unit}"],
lambda x, y: x + pd.DateOffset(days=int(y)),
)
# Add location column
if location:
q_ensemble_data["location"] = location
# Dropping columns specified by user
if drop_column_names:
q_ensemble_data = q_ensemble_data.drop(columns=drop_column_names)
return q_ensemble_data


# --- Intervention weaving utilities ----
Expand Down
40 changes: 29 additions & 11 deletions pyciemss/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
from pyciemss.integration_utils.custom_decorators import pyciemss_logging_wrapper
from pyciemss.integration_utils.interface_checks import check_solver
from pyciemss.integration_utils.observation import compile_noise_model, load_data
from pyciemss.integration_utils.result_processing import prepare_interchange_dictionary
from pyciemss.integration_utils.result_processing import (
DEFAULT_ALPHA_QS,
prepare_interchange_dictionary,
)
from pyciemss.interruptions import (
DynamicParameterIntervention,
ParameterInterventionTracer,
Expand Down Expand Up @@ -52,7 +55,9 @@ def ensemble_sample(
start_time: float = 0.0,
inferred_parameters: Optional[pyro.nn.PyroModule] = None,
time_unit: Optional[str] = None,
):
alpha_qs: Optional[List[float]] = DEFAULT_ALPHA_QS,
stacking_order: str = "timepoints",
) -> Dict[str, Any]:
"""
Load a collection of models from files, compile them into an ensemble probabilistic program,
and sample from the ensemble.
Expand Down Expand Up @@ -97,13 +102,22 @@ def ensemble_sample(
- A Pyro module that contains the inferred parameters of the model.
This is typically the result of `calibrate`.
- If not provided, we will use the default values from the AMR model.
alpha_qs: Optional[List[float]]
- The quantiles required for estimating weighted interval score to test ensemble forecasting accuracy.
stacking_order: Optional[str]
- The stacking order requested for the ensemble quantiles to keep the selected quantity together for each state.
- Options: "timepoints" or "quantiles"
Returns:
result: Dict[str, torch.Tensor]
- Dictionary of outputs from the model.
- Each key is the name of a parameter or state variable in the model.
- Each value is a tensor of shape (num_samples, num_timepoints) for state variables
result: Dict[str, Any]
- Dictionary of outputs with following attributes:
- data: The samples from the model as a pandas DataFrame.
- unprocessed_result: Dictionary of outputs from the model.
- Each key is the name of a parameter or state variable in the model.
- Each value is a tensor of shape (num_samples, num_timepoints) for state variables
and (num_samples,) for parameters.
- ensemble_quantiles: The quantiles for ensemble score calculation as a pandas DataFrames.
- schema: Visualization. (If visual_options is truthy)
"""
check_solver(solver_method, solver_options)

Expand All @@ -116,7 +130,7 @@ def ensemble_sample(
)

logging_times = torch.arange(
start_time + logging_step_size, end_time, logging_step_size
start_time, end_time + logging_step_size, logging_step_size
)

# Check that num_samples is a positive integer
Expand Down Expand Up @@ -151,7 +165,12 @@ def wrapped_model():
)()

return prepare_interchange_dictionary(
samples, timepoints=logging_times, time_unit=time_unit
samples,
timepoints=logging_times,
time_unit=time_unit,
ensemble_quantiles=True,
alpha_qs=alpha_qs,
stacking_order=stacking_order,
)


Expand Down Expand Up @@ -410,14 +429,13 @@ def sample(
- Risk level for alpha-superquantile outputs in the results dictionary.
Returns:
result: Dict[str, torch.Tensor]
result: Dict[str, Any]
- Dictionary of outputs with following attributes:
- data: The samples from the model as a pandas DataFrame.
- unprocessed_result: Dictionary of outputs from the model.
- Each key is the name of a parameter or state variable in the model.
- Each value is a tensor of shape (num_samples, num_timepoints) for state variables
and (num_samples,) for parameters.
- quantiles: The quantiles for ensemble score calculation as a pandas DataFrames.
- risk: Dictionary with each key as the name of a state with
a dictionary of risk estimates for each state at the final timepoint.
- risk: alpha-superquantile risk estimate
Expand All @@ -435,7 +453,7 @@ def sample(
model = CompiledDynamics.load(model_path_or_json)

logging_times = torch.arange(
start_time + logging_step_size, end_time, logging_step_size
start_time, end_time + logging_step_size, logging_step_size
)

# Check that num_samples is a positive integer
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def check_result_sizes(
assert isinstance(v, torch.Tensor)

num_timesteps = len(
torch.arange(start_time + logging_step_size, end_time, logging_step_size)
torch.arange(start_time, end_time + logging_step_size, logging_step_size)
)
if v.ndim == 2 and k == "model_weights":
assert v.shape[0] == num_samples
Expand Down
Loading

0 comments on commit a18eb59

Please sign in to comment.